@@ -9201,6 +9201,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
9201
9201
" }\n"
9202
9202
" return %arg0 : !torch.list<int>\n"
9203
9203
" }\n"
9204
+ " func.func @\"__torch_mlir_shape_fn.aten.count_nonzero\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>) -> !torch.list<int> {\n"
9205
+ " %false = torch.constant.bool false\n"
9206
+ " %str = torch.constant.str \"AssertionError: \"\n"
9207
+ " %true = torch.constant.bool true\n"
9208
+ " %none = torch.constant.none\n"
9209
+ " %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
9210
+ " %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
9211
+ " %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
9212
+ " torch.prim.If.yield %2 : !torch.list<int>\n"
9213
+ " } else {\n"
9214
+ " %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
9215
+ " %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
9216
+ " %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int\n"
9217
+ " %5 = torch.aten.lt.int %2, %4 : !torch.int, !torch.int -> !torch.bool\n"
9218
+ " %6 = torch.prim.If %5 -> (!torch.bool) {\n"
9219
+ " torch.prim.If.yield %true : !torch.bool\n"
9220
+ " } else {\n"
9221
+ " %9 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
9222
+ " %10 = torch.aten.ge.int %2, %9 : !torch.int, !torch.int -> !torch.bool\n"
9223
+ " torch.prim.If.yield %10 : !torch.bool\n"
9224
+ " }\n"
9225
+ " %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
9226
+ " torch.prim.If %7 -> () {\n"
9227
+ " torch.prim.If.yield\n"
9228
+ " } else {\n"
9229
+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
9230
+ " torch.prim.If.yield\n"
9231
+ " }\n"
9232
+ " %8 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %false) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
9233
+ " torch.prim.If.yield %8 : !torch.list<int>\n"
9234
+ " }\n"
9235
+ " return %1 : !torch.list<int>\n"
9236
+ " }\n"
9204
9237
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
9205
9238
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
9206
9239
" return %0 : !torch.list<int>\n"
@@ -15821,6 +15854,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
15821
15854
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
15822
15855
" return %0#1 : !torch.int\n"
15823
15856
" }\n"
15857
+ " func.func @\"__torch_mlir_dtype_fn.aten.count_nonzero\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
15858
+ " %int4 = torch.constant.int 4\n"
15859
+ " return %int4 : !torch.int\n"
15860
+ " }\n"
15824
15861
" func.func @\"__torch_mlir_dtype_fn.aten.rot90\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
15825
15862
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
15826
15863
" return %0#1 : !torch.int\n"
0 commit comments