@@ -11192,6 +11192,134 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1119211192" }\n"
1119311193" return %14 : !torch.list<int>\n"
1119411194" }\n"
11195+ " func.func @\"__torch_mlir_shape_fn.aten.stft.center\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.str, %arg7: !torch.bool, %arg8: !torch.optional<bool>, %arg9: !torch.optional<bool>, %arg10: !torch.optional<bool>) -> !torch.list<int> {\n"
11196+ " %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n"
11197+ " %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n"
11198+ " %false = torch.constant.bool false\n"
11199+ " %none = torch.constant.none\n"
11200+ " %str_1 = torch.constant.str \"AssertionError: Expected input tensor to be of shape (B?,L), where B is an optional batch dimension\"\n"
11201+ " %true = torch.constant.bool true\n"
11202+ " %int1 = torch.constant.int 1\n"
11203+ " %int2 = torch.constant.int 2\n"
11204+ " %int0 = torch.constant.int 0\n"
11205+ " %int4 = torch.constant.int 4\n"
11206+ " %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11207+ " %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
11208+ " %2 = torch.prim.If %1 -> (!torch.bool) {\n"
11209+ " torch.prim.If.yield %true : !torch.bool\n"
11210+ " } else {\n"
11211+ " %21 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11212+ " %22 = torch.aten.eq.int %21, %int2 : !torch.int, !torch.int -> !torch.bool\n"
11213+ " torch.prim.If.yield %22 : !torch.bool\n"
11214+ " }\n"
11215+ " torch.prim.If %2 -> () {\n"
11216+ " torch.prim.If.yield\n"
11217+ " } else {\n"
11218+ " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
11219+ " torch.prim.If.yield\n"
11220+ " }\n"
11221+ " %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11222+ " %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
11223+ " %5 = torch.prim.If %4 -> (!torch.optional<int>) {\n"
11224+ " %21 = torch.derefine %none : !torch.none to !torch.optional<int>\n"
11225+ " torch.prim.If.yield %21 : !torch.optional<int>\n"
11226+ " } else {\n"
11227+ " %21 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
11228+ " %22 = torch.derefine %21 : !torch.int to !torch.optional<int>\n"
11229+ " torch.prim.If.yield %22 : !torch.optional<int>\n"
11230+ " }\n"
11231+ " %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11232+ " %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n"
11233+ " %8 = torch.prim.If %7 -> (!torch.int) {\n"
11234+ " %21 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
11235+ " torch.prim.If.yield %21 : !torch.int\n"
11236+ " } else {\n"
11237+ " %21 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
11238+ " torch.prim.If.yield %21 : !torch.int\n"
11239+ " }\n"
11240+ " %9 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
11241+ " %10 = torch.prim.If %9 -> (!torch.int) {\n"
11242+ " %21 = torch.aten.floordiv.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n"
11243+ " torch.prim.If.yield %21 : !torch.int\n"
11244+ " } else {\n"
11245+ " %21 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
11246+ " torch.prim.If.yield %21 : !torch.int\n"
11247+ " }\n"
11248+ " %11 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
11249+ " %12 = torch.prim.If %11 -> (!torch.bool) {\n"
11250+ " %21 = torch.aten.le.int %arg1, %8 : !torch.int, !torch.int -> !torch.bool\n"
11251+ " torch.prim.If.yield %21 : !torch.bool\n"
11252+ " } else {\n"
11253+ " torch.prim.If.yield %false : !torch.bool\n"
11254+ " }\n"
11255+ " torch.prim.If %12 -> () {\n"
11256+ " torch.prim.If.yield\n"
11257+ " } else {\n"
11258+ " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
11259+ " torch.prim.If.yield\n"
11260+ " }\n"
11261+ " %13 = torch.aten.gt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n"
11262+ " torch.prim.If %13 -> () {\n"
11263+ " torch.prim.If.yield\n"
11264+ " } else {\n"
11265+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
11266+ " torch.prim.If.yield\n"
11267+ " }\n"
11268+ " %14 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
11269+ " %15 = torch.aten.__isnot__ %5, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
11270+ " torch.prim.If %15 -> () {\n"
11271+ " %21 = torch.prim.unchecked_cast %5 : !torch.optional<int> -> !torch.int\n"
11272+ " %22 = torch.aten.append.t %14, %21 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11273+ " torch.prim.If.yield\n"
11274+ " } else {\n"
11275+ " torch.prim.If.yield\n"
11276+ " }\n"
11277+ " %16 = torch.aten.__is__ %arg8, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
11278+ " %17 = torch.prim.If %16 -> (!torch.bool) {\n"
11279+ " torch.prim.If.yield %true : !torch.bool\n"
11280+ " } else {\n"
11281+ " %21 = torch.prim.unchecked_cast %arg8 : !torch.optional<bool> -> !torch.bool\n"
11282+ " %22 = torch.aten.eq.bool %21, %true : !torch.bool, !torch.bool -> !torch.bool\n"
11283+ " torch.prim.If.yield %22 : !torch.bool\n"
11284+ " }\n"
11285+ " torch.prim.If %17 -> () {\n"
11286+ " %21 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n"
11287+ " %22 = torch.aten.add.int %21, %int1 : !torch.int, !torch.int -> !torch.int\n"
11288+ " %23 = torch.aten.append.t %14, %22 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11289+ " torch.prim.If.yield\n"
11290+ " } else {\n"
11291+ " %21 = torch.aten.append.t %14, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11292+ " torch.prim.If.yield\n"
11293+ " }\n"
11294+ " %18 = torch.aten.eq.bool %arg5, %true : !torch.bool, !torch.bool -> !torch.bool\n"
11295+ " torch.prim.If %18 -> () {\n"
11296+ " %21 = torch.aten.floordiv.int %8, %10 : !torch.int, !torch.int -> !torch.int\n"
11297+ " %22 = torch.aten.add.int %int1, %21 : !torch.int, !torch.int -> !torch.int\n"
11298+ " %23 = torch.aten.append.t %14, %22 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11299+ " torch.prim.If.yield\n"
11300+ " } else {\n"
11301+ " %21 = torch.aten.sub.int %8, %arg1 : !torch.int, !torch.int -> !torch.int\n"
11302+ " %22 = torch.aten.floordiv.int %21, %10 : !torch.int, !torch.int -> !torch.int\n"
11303+ " %23 = torch.aten.add.int %int1, %22 : !torch.int, !torch.int -> !torch.int\n"
11304+ " %24 = torch.aten.append.t %14, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11305+ " torch.prim.If.yield\n"
11306+ " }\n"
11307+ " %19 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
11308+ " %20 = torch.prim.If %19 -> (!torch.bool) {\n"
11309+ " %21 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
11310+ " %22 = torch.aten.eq.bool %21, %false : !torch.bool, !torch.bool -> !torch.bool\n"
11311+ " torch.prim.If.yield %22 : !torch.bool\n"
11312+ " } else {\n"
11313+ " torch.prim.If.yield %false : !torch.bool\n"
11314+ " }\n"
11315+ " torch.prim.If %20 -> () {\n"
11316+ " %21 = torch.aten.append.t %14, %int2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11317+ " torch.prim.If.yield\n"
11318+ " } else {\n"
11319+ " torch.prim.If.yield\n"
11320+ " }\n"
11321+ " return %14 : !torch.list<int>\n"
11322+ " }\n"
1119511323" func.func @\"__torch_mlir_shape_fn.aten.fft_ifft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
1119611324" return %arg0 : !torch.list<int>\n"
1119711325" }\n"
@@ -13413,6 +13541,143 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1341313541" }\n"
1341413542" return %6 : !torch.int\n"
1341513543" }\n"
13544+ " func.func @\"__torch_mlir_dtype_fn.aten.stft.center\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.str, %arg7: !torch.bool, %arg8: !torch.optional<bool>, %arg9: !torch.optional<bool>, %arg10: !torch.optional<bool>) -> !torch.int {\n"
13545+ " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
13546+ " %int7 = torch.constant.int 7\n"
13547+ " %int10 = torch.constant.int 10\n"
13548+ " %int6 = torch.constant.int 6\n"
13549+ " %int9 = torch.constant.int 9\n"
13550+ " %int5 = torch.constant.int 5\n"
13551+ " %int8 = torch.constant.int 8\n"
13552+ " %none = torch.constant.none\n"
13553+ " %false = torch.constant.bool false\n"
13554+ " %true = torch.constant.bool true\n"
13555+ " %0 = torch.prim.Uninitialized : !torch.int\n"
13556+ " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
13557+ " %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13558+ " %3 = torch.prim.If %2 -> (!torch.bool) {\n"
13559+ " %7 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13560+ " torch.prim.If.yield %7 : !torch.bool\n"
13561+ " } else {\n"
13562+ " torch.prim.If.yield %false : !torch.bool\n"
13563+ " }\n"
13564+ " %4 = torch.prim.If %3 -> (!torch.bool) {\n"
13565+ " %7 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13566+ " torch.prim.If.yield %7 : !torch.bool\n"
13567+ " } else {\n"
13568+ " torch.prim.If.yield %false : !torch.bool\n"
13569+ " }\n"
13570+ " %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n"
13571+ " torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
13572+ " } else {\n"
13573+ " %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13574+ " %8 = torch.prim.If %7 -> (!torch.bool) {\n"
13575+ " %11 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13576+ " torch.prim.If.yield %11 : !torch.bool\n"
13577+ " } else {\n"
13578+ " torch.prim.If.yield %false : !torch.bool\n"
13579+ " }\n"
13580+ " %9 = torch.prim.If %8 -> (!torch.bool) {\n"
13581+ " %11 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13582+ " %12 = torch.aten.ne.bool %11, %true : !torch.bool, !torch.bool -> !torch.bool\n"
13583+ " torch.prim.If.yield %12 : !torch.bool\n"
13584+ " } else {\n"
13585+ " torch.prim.If.yield %false : !torch.bool\n"
13586+ " }\n"
13587+ " %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.int) {\n"
13588+ " %11 = torch.aten.eq.int %1#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
13589+ " %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.int) {\n"
13590+ " torch.prim.If.yield %true, %int5 : !torch.bool, !torch.int\n"
13591+ " } else {\n"
13592+ " %13 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
13593+ " %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
13594+ " torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n"
13595+ " } else {\n"
13596+ " %15 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
13597+ " %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
13598+ " torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n"
13599+ " } else {\n"
13600+ " torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
13601+ " }\n"
13602+ " torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
13603+ " }\n"
13604+ " torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
13605+ " }\n"
13606+ " torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.int\n"
13607+ " } else {\n"
13608+ " %11 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13609+ " %12 = torch.prim.If %11 -> (!torch.bool) {\n"
13610+ " %15 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13611+ " torch.prim.If.yield %15 : !torch.bool\n"
13612+ " } else {\n"
13613+ " torch.prim.If.yield %false : !torch.bool\n"
13614+ " }\n"
13615+ " %13 = torch.prim.If %12 -> (!torch.bool) {\n"
13616+ " %15 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13617+ " torch.prim.If.yield %15 : !torch.bool\n"
13618+ " } else {\n"
13619+ " torch.prim.If.yield %false : !torch.bool\n"
13620+ " }\n"
13621+ " %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
13622+ " %15 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
13623+ " %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
13624+ " torch.prim.If.yield %true, %int8 : !torch.bool, !torch.int\n"
13625+ " } else {\n"
13626+ " %17 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
13627+ " %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
13628+ " torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
13629+ " } else {\n"
13630+ " %19 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
13631+ " %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
13632+ " torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n"
13633+ " } else {\n"
13634+ " torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
13635+ " }\n"
13636+ " torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
13637+ " }\n"
13638+ " torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
13639+ " }\n"
13640+ " torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
13641+ " } else {\n"
13642+ " %15 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13643+ " %16 = torch.prim.If %15 -> (!torch.bool) {\n"
13644+ " %19 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13645+ " torch.prim.If.yield %19 : !torch.bool\n"
13646+ " } else {\n"
13647+ " torch.prim.If.yield %false : !torch.bool\n"
13648+ " }\n"
13649+ " %17 = torch.prim.If %16 -> (!torch.bool) {\n"
13650+ " %19 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13651+ " %20 = torch.aten.ne.bool %19, %true : !torch.bool, !torch.bool -> !torch.bool\n"
13652+ " torch.prim.If.yield %20 : !torch.bool\n"
13653+ " } else {\n"
13654+ " torch.prim.If.yield %false : !torch.bool\n"
13655+ " }\n"
13656+ " %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
13657+ " torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
13658+ " } else {\n"
13659+ " %19 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13660+ " %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
13661+ " torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
13662+ " } else {\n"
13663+ " torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
13664+ " }\n"
13665+ " torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
13666+ " }\n"
13667+ " torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
13668+ " }\n"
13669+ " torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
13670+ " }\n"
13671+ " torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.int\n"
13672+ " }\n"
13673+ " %6 = torch.prim.If %5#0 -> (!torch.int) {\n"
13674+ " torch.prim.If.yield %5#1 : !torch.int\n"
13675+ " } else {\n"
13676+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
13677+ " torch.prim.If.yield %0 : !torch.int\n"
13678+ " }\n"
13679+ " return %6 : !torch.int\n"
13680+ " }\n"
1341613681" func.func @\"__torch_mlir_dtype_fn.aten.fft_ifft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
1341713682" %none = torch.constant.none\n"
1341813683" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
0 commit comments