@@ -15051,18 +15051,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
15051
15051
" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
15052
15052
" %none = torch.constant.none\n"
15053
15053
" %str = torch.constant.str \"AssertionError: \"\n"
15054
+ " %int3 = torch.constant.int 3\n"
15055
+ " %true = torch.constant.bool true\n"
15054
15056
" %int4 = torch.constant.int 4\n"
15055
15057
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
15056
15058
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
15057
15059
" %2 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
15058
- " torch.prim.If %2 -> () {\n"
15060
+ " %3 = torch.prim.If %2 -> (!torch.bool) {\n"
15061
+ " torch.prim.If.yield %true : !torch.bool\n"
15062
+ " } else {\n"
15063
+ " %5 = torch.aten.eq.int %1#1, %int3 : !torch.int, !torch.int -> !torch.bool\n"
15064
+ " torch.prim.If.yield %5 : !torch.bool\n"
15065
+ " }\n"
15066
+ " torch.prim.If %3 -> () {\n"
15059
15067
" torch.prim.If.yield\n"
15060
15068
" } else {\n"
15061
15069
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
15062
15070
" torch.prim.If.yield\n"
15063
15071
" }\n"
15064
- " %3 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
15065
- " return %3 : !torch.tuple<int, int>\n"
15072
+ " %4 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
15073
+ " return %4 : !torch.tuple<int, int>\n"
15066
15074
" }\n"
15067
15075
" func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float) -> !torch.tuple<int, int, int> {\n"
15068
15076
" %int7 = torch.constant.int 7\n"
0 commit comments