Commit aa84fb5
[TorchToLinalg] : Lower count_nonzero.dim_IntList to Linalg-on-Tensors (#4146)
- Lowered the count_nonzero.dim_IntList op to Linalg-on-Tensors
- Decomposed aten.count_nonzero.dim_IntList using aten.ne.Scalar and aten.sum.dim_IntList
- Added e2e tests to projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
1 parent 64a9be7 commit aa84fb5
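At the PyTorch level, the decomposition this commit implements amounts to the following rewrite (a minimal sketch for illustration, not code from the commit):

    import torch

    def count_nonzero_dim_list(x: torch.Tensor, dims: list[int]) -> torch.Tensor:
        # aten.ne.Scalar: boolean mask marking the nonzero elements.
        mask = torch.ne(x, 0)
        # aten.sum.dim_IntList (keepdim=False): count nonzeros per reduced slice.
        return torch.sum(mask, dim=dims)

    x = torch.tensor([[0, 1, 2], [3, 0, 0]])
    assert torch.equal(count_nonzero_dim_list(x, [1]),
                       torch.count_nonzero(x, dim=(1,)))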

9 files changed: +199 additions, −0 deletions

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions

@@ -9927,6 +9927,31 @@ def Torch_AtenCountNonzeroOp : Torch_Op<"aten.count_nonzero", [
   let hasVerifier = 1;
 }
 
+def Torch_AtenCountNonzeroDimIntListOp : Torch_Op<"aten.count_nonzero.dim_IntList", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::count_nonzero.dim_IntList : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCountNonzeroDimIntListOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenCountNonzeroDimIntListOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
   AllowsTypeRefinement,
   HasValueSemantics,

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 26 additions & 0 deletions

@@ -6060,6 +6060,32 @@ LogicalResult AtenCountNonzeroOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenCountNonzeroDimIntListOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenCountNonzeroDimIntListOp::verify() {
+
+  auto selfType = cast<BaseTensorType>(getSelf().getType());
+
+  if (!selfType.hasDtype() || !selfType.hasSizes())
+    return success();
+
+  SmallVector<int64_t> dims;
+  if (!matchPattern(getDim(), m_TorchListOfConstantInts(dims)))
+    return emitOpError("expected dim to be constructed from list construct");
+
+  int64_t selfRank = selfType.getSizes().size();
+
+  for (auto d : dims) {
+    if (d >= selfRank || d < -selfRank)
+      return emitOpError("expected to be in [ ")
+             << -selfRank << " , " << selfRank - 1 << " ], but got dim = " << d;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OnnxVariantRotaryEmbeddingOp
 //===----------------------------------------------------------------------===//
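The verifier enforces the usual PyTorch bound for reduction dims: every entry of `dim` must lie in [-rank, rank - 1]. Restated in plain Python (a sketch, not part of the commit):

    def check_dims(dims: list[int], rank: int) -> None:
        # Mirrors AtenCountNonzeroDimIntListOp::verify(): reject any dim
        # outside the closed interval [-rank, rank - 1].
        for d in dims:
            if d >= rank or d < -rank:
                raise ValueError(f"expected dim in [{-rank}, {rank - 1}], got {d}")

    check_dims([-1, 1], rank=3)   # passes
    # check_dims([3], rank=3)     # raises: 3 is out of range for rank 3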

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 34 additions & 0 deletions

@@ -9234,6 +9234,40 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  }\n"
 "  return %1 : !torch.list<int>\n"
 "}\n"
+"func.func @\"__torch_mlir_shape_fn.aten.count_nonzero.dim_IntList\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"  %false = torch.constant.bool false\n"
+"  %none = torch.constant.none\n"
+"  %str = torch.constant.str \"AssertionError: \"\n"
+"  %true = torch.constant.bool true\n"
+"  %int0 = torch.constant.int 0\n"
+"  %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"  torch.prim.Loop %0, %true, init() {\n"
+"  ^bb0(%arg2: !torch.int):\n"
+"    %4 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.aten.neg.int %5 : !torch.int -> !torch.int\n"
+"    %7 = torch.aten.lt.int %4, %6 : !torch.int, !torch.int -> !torch.bool\n"
+"    %8 = torch.prim.If %7 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %11 = torch.aten.ge.int %4, %10 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %11 : !torch.bool\n"
+"    }\n"
+"    %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %9 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    torch.prim.Loop.condition %true, iter()\n"
+"  } : (!torch.int, !torch.bool) -> ()\n"
+"  %1 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
+"  %2 = torch.derefine %int0 : !torch.int to !torch.any\n"
+"  %3 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %1, %false, %2) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"  return %3 : !torch.list<int>\n"
+"}\n"
 "func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
 "  %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "  return %0 : !torch.list<int>\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 38 additions & 0 deletions

@@ -6702,6 +6702,42 @@ class DecomposeAtenCountNonzeroOp
 };
 } // namespace
 
+// Decompose aten.count_nonzero to aten.ne.Scalar and aten.sum.dim_IntList
+namespace {
+class DecomposeAtenCountNonzeroDimIntListOp
+    : public OpRewritePattern<AtenCountNonzeroDimIntListOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenCountNonzeroDimIntListOp op,
+                                PatternRewriter &rewriter) const override {
+    Value dimList = op.getDim();
+    if (isa<Torch::NoneType>(dimList.getType()))
+      return rewriter.notifyMatchFailure(
+          op, "expected `dim` to be constructed from list");
+
+    SmallVector<int64_t> dimIntListElem;
+    if (!matchPattern(dimList, m_TorchListOfConstantInts(dimIntListElem)))
+      return rewriter.notifyMatchFailure(
+          op, "expected `dim` to be constructed from list of integers");
+
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    BaseTensorType inputType = cast<BaseTensorType>(self.getType());
+    auto inpBoolTy = inputType.getWithSizesAndDtype(inputType.getSizes(),
+                                                    rewriter.getI1Type());
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value nonZeroMask =
+        rewriter.create<AtenNeScalarOp>(loc, inpBoolTy, self, cstZero);
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+    rewriter.replaceOpWithNewOp<AtenSumDimIntListOp>(
+        op, op.getResult().getType(), nonZeroMask, dimList, cstFalse, none);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.std.correction to sqrt(var.correction(x))
 namespace {
 class DecomposeAtenStdCorrectionOp

@@ -12064,6 +12100,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRot90Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCountNonzeroOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenCountNonzeroDimIntListOp>(
+        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
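The pattern rewrites the op into a sum over an i1 mask, passing keepdim=false and dtype=none. In eager PyTorch terms, the replacement computes the following (a sketch under those assumptions):

    import torch

    x = torch.rand(2, 3, 4)
    mask = torch.ne(x, 0)  # i1 mask with the same sizes as x
    out = torch.sum(mask, dim=[1], keepdim=False, dtype=None)
    print(out.dtype)   # torch.int64: summing a bool mask promotes to int64
    print(out.shape)   # torch.Size([2, 4]): dim 1 reduced away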

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions

@@ -414,6 +414,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenRenormOp>();
   target.addIllegalOp<AtenRot90Op>();
   target.addIllegalOp<AtenCountNonzeroOp>();
+  target.addIllegalOp<AtenCountNonzeroDimIntListOp>();
   target.addIllegalOp<AtenLinalgCrossOp>();
   target.addIllegalOp<Aten_LinalgDetOp>();
   target.addIllegalOp<AtenLinalgSlogdetOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions

@@ -2862,6 +2862,9 @@
     "CountNonzeroModuleBool_Basic",
     "CountNonzeroModuleF32_basic",
     "CountNonzeroModuleI64_basic",
+    "CountNonzeroDimIntListModuleBool_Basic",
+    "CountNonzeroDimIntListModuleI64_basic",
+    "CountNonzeroDimIntListModuleF32_basic",
     "Deg2radModule_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 17 additions & 0 deletions

@@ -1453,6 +1453,23 @@ def aten〇count_nonzero〡shape(self: List[int], dim: Optional[int] = None) ->
     assert not (dim < -len(self) or dim >= len(self))
     return upstream_shape_functions.argmax(self, dim)
 
+@check_shape_function([
+    Invocation(TensorOfShape(2, 3, 4), dim=[]), # Basic case.
+    Invocation(TensorOfShape(2, 3, 4), dim=[1]), # Test explicit dim.
+    Invocation(TensorOfShape(2, 3, 4), dim=[-1]), # Test explicit dim.
+    Invocation(TensorOfShape(2, 3, 4), dim=[0,1]), # Test explicit dim.
+    Invocation(TensorOfShape(2, 3, 4), dim=[-3]), # Test explicit dim (negative).
+    Invocation(TensorOfShape(2, 3, 4), dim=[2]), # Maximum valid dim.
+    ErrorInvocation(TensorOfShape(2, 3, 4), dim=[-4]), # Test dim out of bound.
+    ErrorInvocation(TensorOfShape(2, 3, 4), dim=[1,-4]), # Test dim out of bound.
+    ErrorInvocation(TensorOfShape(2, 3, 4), dim=[3]), # Test dim out of bound.
+])
+def aten〇count_nonzero〇dim_IntList〡shape(self: List[int], dim: List[int]) -> List[int]:
+    if dim is None: return []
+    for d in dim:
+        assert not (d < -len(self) or d >= len(self))
+    return upstream_shape_functions.sum_mean_dim(self, dim, keep_dim=False, dt=0)
+
 def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
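Since the shape function defers to upstream_shape_functions.sum_mean_dim with keep_dim=False, each reduced dimension is dropped from the result shape. A plain-Python sketch of the expectation (expected_shape is a hypothetical helper for illustration, assuming the usual negative-dim normalization and that an empty dim list reduces all dims):

    def expected_shape(self_shape: list[int], dims: list[int]) -> list[int]:
        # Illustration only: normalize negative dims, then drop every
        # reduced dimension (keep_dim=False).
        rank = len(self_shape)
        reduced = {d % rank for d in dims} if dims else set(range(rank))
        return [s for i, s in enumerate(self_shape) if i not in reduced]

    print(expected_shape([2, 3, 4], [1]))      # [2, 4]
    print(expected_shape([2, 3, 4], [-1, 0]))  # [3]
    print(expected_shape([2, 3, 4], []))       # []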

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 4 additions & 0 deletions

@@ -788,6 +788,10 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)")
     emit("aten::rot90 : (Tensor, int, int[]) -> (Tensor)", has_verifier=True)
     emit("aten::count_nonzero : (Tensor, int?) -> (Tensor)", has_verifier=True)
+    emit(
+        "aten::count_nonzero.dim_IntList : (Tensor, int[]) -> (Tensor)",
+        has_verifier=True,
+    )
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 51 additions & 0 deletions

@@ -2546,3 +2546,54 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: CountNonzeroModuleBool())
 def CountNonzeroModuleBool_Basic(module, tu: TestUtils):
     module.forward(tu.randint(2, 3, 4, low=0, high=2).to(torch.bool))
+
+
+# ==============================================================================
+
+
+class CountNonzeroDimIntListModuleF32(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.dim = [-1, 1]
+
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.count_nonzero(x, self.dim)
+
+
+@register_test_case(module_factory=lambda: CountNonzeroDimIntListModuleF32())
+def CountNonzeroDimIntListModuleF32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+
+class CountNonzeroDimIntListModuleI64(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.dim = [-2, -0, -1]
+
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.count_nonzero(x, self.dim)
+
+
+@register_test_case(module_factory=lambda: CountNonzeroDimIntListModuleI64())
+def CountNonzeroDimIntListModuleI64_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3, 4))
+
+
+class CountNonzeroDimIntListModuleBool(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.dim = [1]
+
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.bool, True)])
+    def forward(self, x):
+        return torch.ops.aten.count_nonzero(x, self.dim)
+
+
+@register_test_case(module_factory=lambda: CountNonzeroDimIntListModuleBool())
+def CountNonzeroDimIntListModuleBool_Basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3, 4, low=0, high=2).to(torch.bool))
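
Run standalone, one of the new cases behaves as follows (illustrative usage, not part of the test harness):

    import torch

    m = CountNonzeroDimIntListModuleF32()  # reduces over dims [-1, 1]
    out = m.forward(torch.rand(2, 3, 4))
    print(out.shape)  # torch.Size([2]): each entry counts nonzeros in a 3x4 slice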
