@@ -9234,6 +9234,40 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
9234
9234
" }\n"
9235
9235
" return %1 : !torch.list<int>\n"
9236
9236
" }\n"
9237
+ " func.func @\"__torch_mlir_shape_fn.aten.count_nonzero.dim_IntList\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
9238
+ " %false = torch.constant.bool false\n"
9239
+ " %none = torch.constant.none\n"
9240
+ " %str = torch.constant.str \"AssertionError: \"\n"
9241
+ " %true = torch.constant.bool true\n"
9242
+ " %int0 = torch.constant.int 0\n"
9243
+ " %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
9244
+ " torch.prim.Loop %0, %true, init() {\n"
9245
+ " ^bb0(%arg2: !torch.int):\n"
9246
+ " %4 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
9247
+ " %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
9248
+ " %6 = torch.aten.neg.int %5 : !torch.int -> !torch.int\n"
9249
+ " %7 = torch.aten.lt.int %4, %6 : !torch.int, !torch.int -> !torch.bool\n"
9250
+ " %8 = torch.prim.If %7 -> (!torch.bool) {\n"
9251
+ " torch.prim.If.yield %true : !torch.bool\n"
9252
+ " } else {\n"
9253
+ " %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
9254
+ " %11 = torch.aten.ge.int %4, %10 : !torch.int, !torch.int -> !torch.bool\n"
9255
+ " torch.prim.If.yield %11 : !torch.bool\n"
9256
+ " }\n"
9257
+ " %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n"
9258
+ " torch.prim.If %9 -> () {\n"
9259
+ " torch.prim.If.yield\n"
9260
+ " } else {\n"
9261
+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
9262
+ " torch.prim.If.yield\n"
9263
+ " }\n"
9264
+ " torch.prim.Loop.condition %true, iter()\n"
9265
+ " } : (!torch.int, !torch.bool) -> ()\n"
9266
+ " %1 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
9267
+ " %2 = torch.derefine %int0 : !torch.int to !torch.any\n"
9268
+ " %3 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %1, %false, %2) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
9269
+ " return %3 : !torch.list<int>\n"
9270
+ " }\n"
9237
9271
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
9238
9272
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
9239
9273
" return %0 : !torch.list<int>\n"
0 commit comments