@@ -7810,6 +7810,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
78107810" %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
78117811" return %0 : !torch.list<int>\n"
78127812" }\n"
7813+ " func.func @\"__torch_mlir_shape_fn.aten.broadcast_tensors\"(%arg0: !torch.list<list<int>>) -> !torch.list<list<int>> {\n"
7814+ " %true = torch.constant.bool true\n"
7815+ " %int0 = torch.constant.int 0\n"
7816+ " %int1 = torch.constant.int 1\n"
7817+ " %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
7818+ " %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
7819+ " %2 = torch.prim.If %1 -> (!torch.list<list<int>>) {\n"
7820+ " %3 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
7821+ " torch.prim.If.yield %3 : !torch.list<list<int>>\n"
7822+ " } else {\n"
7823+ " %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
7824+ " %4 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
7825+ " %5 = torch.aten.__range_length %int1, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
7826+ " %6 = torch.prim.Loop %5, %true, init(%3) {\n"
7827+ " ^bb0(%arg1: !torch.int, %arg2: !torch.list<int>):\n"
7828+ " %9 = torch.aten.__derive_index %arg1, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
7829+ " %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
7830+ " %11 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg2, %10) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
7831+ " torch.prim.Loop.condition %true, iter(%11 : !torch.list<int>)\n"
7832+ " } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
7833+ " %7 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
7834+ " %8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
7835+ " torch.prim.Loop %8, %true, init() {\n"
7836+ " ^bb0(%arg1: !torch.int):\n"
7837+ " %9 = torch.aten.append.t %7, %6 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
7838+ " torch.prim.Loop.condition %true, iter()\n"
7839+ " } : (!torch.int, !torch.bool) -> ()\n"
7840+ " torch.prim.If.yield %7 : !torch.list<list<int>>\n"
7841+ " }\n"
7842+ " return %2 : !torch.list<list<int>>\n"
7843+ " }\n"
78137844" func.func @\"__torch_mlir_shape_fn.aten.view\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
78147845" %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
78157846" return %0 : !torch.list<int>\n"
@@ -12556,6 +12587,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1255612587" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1255712588" return %0#1 : !torch.int\n"
1255812589" }\n"
12590+ " func.func @\"__torch_mlir_dtype_fn.aten.broadcast_tensors\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.list<tuple<int, int>> {\n"
12591+ " %true = torch.constant.bool true\n"
12592+ " %int0 = torch.constant.int 0\n"
12593+ " %0 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
12594+ " %1 = torch.prim.Loop %0, %true, init(%int0) {\n"
12595+ " ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
12596+ " %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
12597+ " %5 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
12598+ " %6 = torch.aten.gt.int %5, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
12599+ " %7 = torch.prim.If %6 -> (!torch.int) {\n"
12600+ " %8 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
12601+ " torch.prim.If.yield %8 : !torch.int\n"
12602+ " } else {\n"
12603+ " torch.prim.If.yield %arg2 : !torch.int\n"
12604+ " }\n"
12605+ " torch.prim.Loop.condition %true, iter(%7 : !torch.int)\n"
12606+ " } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
12607+ " %2 = torch.prim.ListConstruct : () -> !torch.list<tuple<int, int>>\n"
12608+ " %3 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
12609+ " torch.prim.Loop %3, %true, init() {\n"
12610+ " ^bb0(%arg1: !torch.int):\n"
12611+ " %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
12612+ " %5:2 = torch.prim.TupleUnpack %4 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12613+ " %6 = torch.prim.TupleConstruct %1, %5#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
12614+ " %7 = torch.aten.append.t %2, %6 : !torch.list<tuple<int, int>>, !torch.tuple<int, int> -> !torch.list<tuple<int, int>>\n"
12615+ " torch.prim.Loop.condition %true, iter()\n"
12616+ " } : (!torch.int, !torch.bool) -> ()\n"
12617+ " return %2 : !torch.list<tuple<int, int>>\n"
12618+ " }\n"
1255912619" func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
1256012620" %int7 = torch.constant.int 7\n"
1256112621" %int6 = torch.constant.int 6\n"
0 commit comments