Skip to content

Commit cccb0c0

Browse files
authored
Add aten.stft.center and decomposition (#3880)
The choice to work with `aten.stft.center` instead of `aten.stft` is because the latter doesn't match the signature that gets exposed (see https://pytorch.org/docs/stable/generated/torch.stft.html).
1 parent 5e5a74a commit cccb0c0

File tree

12 files changed

+1345
-8
lines changed

12 files changed

+1345
-8
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13730,6 +13730,40 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [
1373013730
printDefaultTorchOp(printer, *this, 9, 1);
1373113731
}
1373213732
}];
13733+
let hasCanonicalizer = 1;
13734+
}
13735+
13736+
def Torch_AtenStftCenterOp : Torch_Op<"aten.stft.center", [
13737+
AllowsTypeRefinement,
13738+
HasValueSemantics,
13739+
ReadOnly
13740+
]> {
13741+
let summary = "Generated op for `aten::stft.center : (Tensor, int, int?, int?, Tensor?, bool, str, bool, bool?, bool?, bool?) -> (Tensor)`";
13742+
let arguments = (ins
13743+
AnyTorchTensorType:$self,
13744+
Torch_IntType:$n_fft,
13745+
AnyTorchOptionalIntType:$hop_length,
13746+
AnyTorchOptionalIntType:$win_length,
13747+
AnyTorchOptionalTensorType:$window,
13748+
Torch_BoolType:$center,
13749+
Torch_StringType:$pad_mode,
13750+
Torch_BoolType:$normalized,
13751+
AnyTorchOptionalBoolType:$onesided,
13752+
AnyTorchOptionalBoolType:$return_complex,
13753+
AnyTorchOptionalBoolType:$align_to_window
13754+
);
13755+
let results = (outs
13756+
AnyTorchOptionalTensorType:$result
13757+
);
13758+
let hasCustomAssemblyFormat = 1;
13759+
let extraClassDefinition = [{
13760+
ParseResult AtenStftCenterOp::parse(OpAsmParser &parser, OperationState &result) {
13761+
return parseDefaultTorchOp(parser, result, 11, 1);
13762+
}
13763+
void AtenStftCenterOp::print(OpAsmPrinter &printer) {
13764+
printDefaultTorchOp(printer, *this, 11, 1);
13765+
}
13766+
}];
1373313767
}
1373413768

1373513769
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3874,6 +3874,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
38743874
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
38753875
Value trueVal =
38763876
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
3877+
Value padMode =
3878+
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), "reflect");
38773879
auto stftTy = complexSignalTy.getWithSizesAndDtype(
38783880
ArrayRef<int64_t>({resultShape[0], resultShape[2], resultShape[1]}),
38793881
complexSignalTy.getDtype());
@@ -3886,10 +3888,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
38863888
// shape of stft to match the shape of resultType. Also, it is
38873889
// immaterial whether torch.view_as_real is called after or before the
38883890
// permutation; both outputs will be equivalent.
3889-
Value stft = rewriter.create<Torch::AtenStftOp>(
3891+
Value stft = rewriter.create<Torch::AtenStftCenterOp>(
38903892
binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem,
3891-
windowLen, window, falseVal, onesided ? trueVal : falseVal, trueVal,
3892-
falseVal);
3893+
windowLen, window, falseVal, padMode, falseVal,
3894+
onesided ? trueVal : falseVal, trueVal, falseVal);
38933895

38943896
auto permuteStftTy = complexSignalTy.getWithSizesAndDtype(
38953897
ArrayRef<int64_t>({resultShape[0], resultShape[1], resultShape[2]}),

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3737,6 +3737,26 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
37373737
return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
37383738
}
37393739

3740+
//===----------------------------------------------------------------------===//
3741+
// AtenStftOp
3742+
//===----------------------------------------------------------------------===//
3743+
3744+
void AtenStftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3745+
MLIRContext *context) {
3746+
patterns.add(+[](AtenStftOp op, PatternRewriter &rewriter) {
3747+
Value falseVal = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
3748+
Value padMode =
3749+
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "reflect");
3750+
rewriter.replaceOpWithNewOp<AtenStftCenterOp>(
3751+
op, op.getType(), op.getSelf(), op.getNFft(), op.getHopLength(),
3752+
op.getWinLength(), op.getWindow(), falseVal, padMode,
3753+
op.getNormalized(), op.getOnesided(), op.getReturnComplex(),
3754+
op.getAlignToWindow());
3755+
3756+
return success();
3757+
});
3758+
}
3759+
37403760
//===----------------------------------------------------------------------===//
37413761
// AtenAliasOp
37423762
//===----------------------------------------------------------------------===//

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11192,6 +11192,134 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1119211192
" }\n"
1119311193
" return %14 : !torch.list<int>\n"
1119411194
" }\n"
11195+
" func.func @\"__torch_mlir_shape_fn.aten.stft.center\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.str, %arg7: !torch.bool, %arg8: !torch.optional<bool>, %arg9: !torch.optional<bool>, %arg10: !torch.optional<bool>) -> !torch.list<int> {\n"
11196+
" %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n"
11197+
" %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n"
11198+
" %false = torch.constant.bool false\n"
11199+
" %none = torch.constant.none\n"
11200+
" %str_1 = torch.constant.str \"AssertionError: Expected input tensor to be of shape (B?,L), where B is an optional batch dimension\"\n"
11201+
" %true = torch.constant.bool true\n"
11202+
" %int1 = torch.constant.int 1\n"
11203+
" %int2 = torch.constant.int 2\n"
11204+
" %int0 = torch.constant.int 0\n"
11205+
" %int4 = torch.constant.int 4\n"
11206+
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11207+
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
11208+
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
11209+
" torch.prim.If.yield %true : !torch.bool\n"
11210+
" } else {\n"
11211+
" %21 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11212+
" %22 = torch.aten.eq.int %21, %int2 : !torch.int, !torch.int -> !torch.bool\n"
11213+
" torch.prim.If.yield %22 : !torch.bool\n"
11214+
" }\n"
11215+
" torch.prim.If %2 -> () {\n"
11216+
" torch.prim.If.yield\n"
11217+
" } else {\n"
11218+
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
11219+
" torch.prim.If.yield\n"
11220+
" }\n"
11221+
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11222+
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
11223+
" %5 = torch.prim.If %4 -> (!torch.optional<int>) {\n"
11224+
" %21 = torch.derefine %none : !torch.none to !torch.optional<int>\n"
11225+
" torch.prim.If.yield %21 : !torch.optional<int>\n"
11226+
" } else {\n"
11227+
" %21 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
11228+
" %22 = torch.derefine %21 : !torch.int to !torch.optional<int>\n"
11229+
" torch.prim.If.yield %22 : !torch.optional<int>\n"
11230+
" }\n"
11231+
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
11232+
" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n"
11233+
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
11234+
" %21 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
11235+
" torch.prim.If.yield %21 : !torch.int\n"
11236+
" } else {\n"
11237+
" %21 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
11238+
" torch.prim.If.yield %21 : !torch.int\n"
11239+
" }\n"
11240+
" %9 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
11241+
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
11242+
" %21 = torch.aten.floordiv.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n"
11243+
" torch.prim.If.yield %21 : !torch.int\n"
11244+
" } else {\n"
11245+
" %21 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
11246+
" torch.prim.If.yield %21 : !torch.int\n"
11247+
" }\n"
11248+
" %11 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
11249+
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
11250+
" %21 = torch.aten.le.int %arg1, %8 : !torch.int, !torch.int -> !torch.bool\n"
11251+
" torch.prim.If.yield %21 : !torch.bool\n"
11252+
" } else {\n"
11253+
" torch.prim.If.yield %false : !torch.bool\n"
11254+
" }\n"
11255+
" torch.prim.If %12 -> () {\n"
11256+
" torch.prim.If.yield\n"
11257+
" } else {\n"
11258+
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
11259+
" torch.prim.If.yield\n"
11260+
" }\n"
11261+
" %13 = torch.aten.gt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n"
11262+
" torch.prim.If %13 -> () {\n"
11263+
" torch.prim.If.yield\n"
11264+
" } else {\n"
11265+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
11266+
" torch.prim.If.yield\n"
11267+
" }\n"
11268+
" %14 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
11269+
" %15 = torch.aten.__isnot__ %5, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
11270+
" torch.prim.If %15 -> () {\n"
11271+
" %21 = torch.prim.unchecked_cast %5 : !torch.optional<int> -> !torch.int\n"
11272+
" %22 = torch.aten.append.t %14, %21 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11273+
" torch.prim.If.yield\n"
11274+
" } else {\n"
11275+
" torch.prim.If.yield\n"
11276+
" }\n"
11277+
" %16 = torch.aten.__is__ %arg8, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
11278+
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
11279+
" torch.prim.If.yield %true : !torch.bool\n"
11280+
" } else {\n"
11281+
" %21 = torch.prim.unchecked_cast %arg8 : !torch.optional<bool> -> !torch.bool\n"
11282+
" %22 = torch.aten.eq.bool %21, %true : !torch.bool, !torch.bool -> !torch.bool\n"
11283+
" torch.prim.If.yield %22 : !torch.bool\n"
11284+
" }\n"
11285+
" torch.prim.If %17 -> () {\n"
11286+
" %21 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n"
11287+
" %22 = torch.aten.add.int %21, %int1 : !torch.int, !torch.int -> !torch.int\n"
11288+
" %23 = torch.aten.append.t %14, %22 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11289+
" torch.prim.If.yield\n"
11290+
" } else {\n"
11291+
" %21 = torch.aten.append.t %14, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11292+
" torch.prim.If.yield\n"
11293+
" }\n"
11294+
" %18 = torch.aten.eq.bool %arg5, %true : !torch.bool, !torch.bool -> !torch.bool\n"
11295+
" torch.prim.If %18 -> () {\n"
11296+
" %21 = torch.aten.floordiv.int %8, %10 : !torch.int, !torch.int -> !torch.int\n"
11297+
" %22 = torch.aten.add.int %int1, %21 : !torch.int, !torch.int -> !torch.int\n"
11298+
" %23 = torch.aten.append.t %14, %22 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11299+
" torch.prim.If.yield\n"
11300+
" } else {\n"
11301+
" %21 = torch.aten.sub.int %8, %arg1 : !torch.int, !torch.int -> !torch.int\n"
11302+
" %22 = torch.aten.floordiv.int %21, %10 : !torch.int, !torch.int -> !torch.int\n"
11303+
" %23 = torch.aten.add.int %int1, %22 : !torch.int, !torch.int -> !torch.int\n"
11304+
" %24 = torch.aten.append.t %14, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11305+
" torch.prim.If.yield\n"
11306+
" }\n"
11307+
" %19 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
11308+
" %20 = torch.prim.If %19 -> (!torch.bool) {\n"
11309+
" %21 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
11310+
" %22 = torch.aten.eq.bool %21, %false : !torch.bool, !torch.bool -> !torch.bool\n"
11311+
" torch.prim.If.yield %22 : !torch.bool\n"
11312+
" } else {\n"
11313+
" torch.prim.If.yield %false : !torch.bool\n"
11314+
" }\n"
11315+
" torch.prim.If %20 -> () {\n"
11316+
" %21 = torch.aten.append.t %14, %int2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11317+
" torch.prim.If.yield\n"
11318+
" } else {\n"
11319+
" torch.prim.If.yield\n"
11320+
" }\n"
11321+
" return %14 : !torch.list<int>\n"
11322+
" }\n"
1119511323
" func.func @\"__torch_mlir_shape_fn.aten.fft_ifft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
1119611324
" return %arg0 : !torch.list<int>\n"
1119711325
" }\n"
@@ -13413,6 +13541,143 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1341313541
" }\n"
1341413542
" return %6 : !torch.int\n"
1341513543
" }\n"
13544+
" func.func @\"__torch_mlir_dtype_fn.aten.stft.center\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.str, %arg7: !torch.bool, %arg8: !torch.optional<bool>, %arg9: !torch.optional<bool>, %arg10: !torch.optional<bool>) -> !torch.int {\n"
13545+
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
13546+
" %int7 = torch.constant.int 7\n"
13547+
" %int10 = torch.constant.int 10\n"
13548+
" %int6 = torch.constant.int 6\n"
13549+
" %int9 = torch.constant.int 9\n"
13550+
" %int5 = torch.constant.int 5\n"
13551+
" %int8 = torch.constant.int 8\n"
13552+
" %none = torch.constant.none\n"
13553+
" %false = torch.constant.bool false\n"
13554+
" %true = torch.constant.bool true\n"
13555+
" %0 = torch.prim.Uninitialized : !torch.int\n"
13556+
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
13557+
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13558+
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
13559+
" %7 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13560+
" torch.prim.If.yield %7 : !torch.bool\n"
13561+
" } else {\n"
13562+
" torch.prim.If.yield %false : !torch.bool\n"
13563+
" }\n"
13564+
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
13565+
" %7 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13566+
" torch.prim.If.yield %7 : !torch.bool\n"
13567+
" } else {\n"
13568+
" torch.prim.If.yield %false : !torch.bool\n"
13569+
" }\n"
13570+
" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n"
13571+
" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
13572+
" } else {\n"
13573+
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13574+
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
13575+
" %11 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13576+
" torch.prim.If.yield %11 : !torch.bool\n"
13577+
" } else {\n"
13578+
" torch.prim.If.yield %false : !torch.bool\n"
13579+
" }\n"
13580+
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
13581+
" %11 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13582+
" %12 = torch.aten.ne.bool %11, %true : !torch.bool, !torch.bool -> !torch.bool\n"
13583+
" torch.prim.If.yield %12 : !torch.bool\n"
13584+
" } else {\n"
13585+
" torch.prim.If.yield %false : !torch.bool\n"
13586+
" }\n"
13587+
" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.int) {\n"
13588+
" %11 = torch.aten.eq.int %1#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
13589+
" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.int) {\n"
13590+
" torch.prim.If.yield %true, %int5 : !torch.bool, !torch.int\n"
13591+
" } else {\n"
13592+
" %13 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
13593+
" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
13594+
" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n"
13595+
" } else {\n"
13596+
" %15 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
13597+
" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
13598+
" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n"
13599+
" } else {\n"
13600+
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
13601+
" }\n"
13602+
" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
13603+
" }\n"
13604+
" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
13605+
" }\n"
13606+
" torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.int\n"
13607+
" } else {\n"
13608+
" %11 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13609+
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
13610+
" %15 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13611+
" torch.prim.If.yield %15 : !torch.bool\n"
13612+
" } else {\n"
13613+
" torch.prim.If.yield %false : !torch.bool\n"
13614+
" }\n"
13615+
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
13616+
" %15 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13617+
" torch.prim.If.yield %15 : !torch.bool\n"
13618+
" } else {\n"
13619+
" torch.prim.If.yield %false : !torch.bool\n"
13620+
" }\n"
13621+
" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
13622+
" %15 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
13623+
" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
13624+
" torch.prim.If.yield %true, %int8 : !torch.bool, !torch.int\n"
13625+
" } else {\n"
13626+
" %17 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
13627+
" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
13628+
" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
13629+
" } else {\n"
13630+
" %19 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
13631+
" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
13632+
" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n"
13633+
" } else {\n"
13634+
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
13635+
" }\n"
13636+
" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
13637+
" }\n"
13638+
" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
13639+
" }\n"
13640+
" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
13641+
" } else {\n"
13642+
" %15 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13643+
" %16 = torch.prim.If %15 -> (!torch.bool) {\n"
13644+
" %19 = torch.aten.__isnot__ %arg9, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
13645+
" torch.prim.If.yield %19 : !torch.bool\n"
13646+
" } else {\n"
13647+
" torch.prim.If.yield %false : !torch.bool\n"
13648+
" }\n"
13649+
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
13650+
" %19 = torch.prim.unchecked_cast %arg9 : !torch.optional<bool> -> !torch.bool\n"
13651+
" %20 = torch.aten.ne.bool %19, %true : !torch.bool, !torch.bool -> !torch.bool\n"
13652+
" torch.prim.If.yield %20 : !torch.bool\n"
13653+
" } else {\n"
13654+
" torch.prim.If.yield %false : !torch.bool\n"
13655+
" }\n"
13656+
" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
13657+
" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
13658+
" } else {\n"
13659+
" %19 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
13660+
" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
13661+
" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
13662+
" } else {\n"
13663+
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
13664+
" }\n"
13665+
" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
13666+
" }\n"
13667+
" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
13668+
" }\n"
13669+
" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
13670+
" }\n"
13671+
" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.int\n"
13672+
" }\n"
13673+
" %6 = torch.prim.If %5#0 -> (!torch.int) {\n"
13674+
" torch.prim.If.yield %5#1 : !torch.int\n"
13675+
" } else {\n"
13676+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
13677+
" torch.prim.If.yield %0 : !torch.int\n"
13678+
" }\n"
13679+
" return %6 : !torch.int\n"
13680+
" }\n"
1341613681
" func.func @\"__torch_mlir_dtype_fn.aten.fft_ifft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
1341713682
" %none = torch.constant.none\n"
1341813683
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"

0 commit comments

Comments
 (0)