@@ -7504,6 +7504,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
7504
7504
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
7505
7505
" return %1 : !torch.list<int>\n"
7506
7506
" }\n"
7507
+ " func.func @\"__torch_mlir_shape_fn.aten.any.dims\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.list<int> {\n"
7508
+ " %none = torch.constant.none\n"
7509
+ " %0 = torch.derefine %none : !torch.none to !torch.any\n"
7510
+ " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
7511
+ " return %1 : !torch.list<int>\n"
7512
+ " }\n"
7507
7513
" func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
7508
7514
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
7509
7515
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
@@ -15420,6 +15426,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
15420
15426
" }\n"
15421
15427
" return %2 : !torch.int\n"
15422
15428
" }\n"
15429
+ " func.func @\"__torch_mlir_dtype_fn.aten.any.dims\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.int {\n"
15430
+ " %int11 = torch.constant.int 11\n"
15431
+ " %int0 = torch.constant.int 0\n"
15432
+ " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
15433
+ " %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
15434
+ " %2 = torch.prim.If %1 -> (!torch.int) {\n"
15435
+ " torch.prim.If.yield %0#1 : !torch.int\n"
15436
+ " } else {\n"
15437
+ " torch.prim.If.yield %int11 : !torch.int\n"
15438
+ " }\n"
15439
+ " return %2 : !torch.int\n"
15440
+ " }\n"
15423
15441
" func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
15424
15442
" %int11 = torch.constant.int 11\n"
15425
15443
" %int0 = torch.constant.int 0\n"
0 commit comments