@@ -9286,6 +9286,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
9286
9286
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
9287
9287
" return %0 : !torch.list<int>\n"
9288
9288
" }\n"
9289
+ " func.func @\"__torch_mlir_shape_fn.aten.logaddexp\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
9290
+ " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
9291
+ " return %0 : !torch.list<int>\n"
9292
+ " }\n"
9293
+ " func.func @\"__torch_mlir_shape_fn.aten.logaddexp2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
9294
+ " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
9295
+ " return %0 : !torch.list<int>\n"
9296
+ " }\n"
9289
9297
" func.func @\"__torch_mlir_shape_fn.aten.masked_fill.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
9290
9298
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
9291
9299
" return %0 : !torch.list<int>\n"
@@ -12796,6 +12804,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
12796
12804
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
12797
12805
" return %arg3 : !torch.int\n"
12798
12806
" }\n"
12807
+ " func.func @\"__torch_mlir_dtype_fn.aten.logaddexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
12808
+ " %none = torch.constant.none\n"
12809
+ " %str = torch.constant.str \"AssertionError: \"\n"
12810
+ " %false = torch.constant.bool false\n"
12811
+ " %int11 = torch.constant.int 11\n"
12812
+ " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12813
+ " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12814
+ " %2 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
12815
+ " %3 = torch.prim.If %2 -> (!torch.bool) {\n"
12816
+ " %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
12817
+ " torch.prim.If.yield %4 : !torch.bool\n"
12818
+ " } else {\n"
12819
+ " torch.prim.If.yield %false : !torch.bool\n"
12820
+ " }\n"
12821
+ " torch.prim.If %3 -> () {\n"
12822
+ " torch.prim.If.yield\n"
12823
+ " } else {\n"
12824
+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
12825
+ " torch.prim.If.yield\n"
12826
+ " }\n"
12827
+ " return %0#1 : !torch.int\n"
12828
+ " }\n"
12829
+ " func.func @\"__torch_mlir_dtype_fn.aten.logaddexp2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
12830
+ " %none = torch.constant.none\n"
12831
+ " %str = torch.constant.str \"AssertionError: \"\n"
12832
+ " %int10 = torch.constant.int 10\n"
12833
+ " %int9 = torch.constant.int 9\n"
12834
+ " %int8 = torch.constant.int 8\n"
12835
+ " %int11 = torch.constant.int 11\n"
12836
+ " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12837
+ " %1 = torch.prim.ListConstruct %int11, %int8, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
12838
+ " %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
12839
+ " %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
12840
+ " torch.prim.If %3 -> () {\n"
12841
+ " torch.prim.If.yield\n"
12842
+ " } else {\n"
12843
+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
12844
+ " torch.prim.If.yield\n"
12845
+ " }\n"
12846
+ " return %0#1 : !torch.int\n"
12847
+ " }\n"
12799
12848
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
12800
12849
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12801
12850
" return %0#1 : !torch.int\n"
0 commit comments