
Commit 5a5cc6b

[MLIR][TORCH] Add aten.special.expm1 op lowering (llvm#3878)
This commit adds support for the torch.aten.special.expm1 op by decomposing it into the torch.aten.expm1 op.

---------

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 31b912e commit 5a5cc6b
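
For context: `expm1(x)` computes e^x − 1 accurately near zero, where the naive `exp(x) - 1` loses precision to cancellation; that is why ATen exposes it (and the `torch.special.expm1` alias) as a dedicated op rather than folding it into `exp`. A minimal illustration using only Python's standard `math` module (not part of this commit):

```python
import math

x = 1e-12
# math.expm1 keeps full precision for tiny arguments ...
print(math.expm1(x))    # ~1.0000000000000001e-12
# ... while the naive formulation cancels catastrophically,
# losing accuracy after only a few digits.
print(math.exp(x) - 1)  # ~1.000089e-12
```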

File tree

8 files changed: +112 −6 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 23 additions & 0 deletions
@@ -4610,6 +4610,29 @@ def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
   }];
 }
 
+def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::special_expm1 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSpecialExpm1Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenSpecialExpm1Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenSignOp : Torch_Op<"aten.sign", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 9 additions & 0 deletions
@@ -6495,6 +6495,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.special_expm1\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -11589,6 +11593,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.special_expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 14 additions & 0 deletions
@@ -11177,6 +11177,19 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenSpecialExpm1Op
+    : public OpRewritePattern<AtenSpecialExpm1Op> {
+public:
+  using OpRewritePattern<AtenSpecialExpm1Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSpecialExpm1Op op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenExpm1Op>(op, op.getType(), op.getSelf());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -11462,6 +11475,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenThresholdOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenFloatPowerTensorTensorOp>(
         patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenSpecialExpm1Op>(patterns);
 
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenFMaxMinOp<AtenFmaxOp, AtenMaximumOp>>(patterns);
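
The new pattern is a pure one-to-one rewrite: it replaces `aten.special_expm1` with `aten.expm1` on the same operand and result type, matching PyTorch's documentation of `torch.special.expm1` as an alias of `torch.expm1`. In eager terms the decomposition amounts to this sketch (illustrative only, not part of the commit):

```python
import torch

def decomposed_special_expm1(x: torch.Tensor) -> torch.Tensor:
    # The decomposition forwards the operand unchanged, since
    # special_expm1 and expm1 share the same semantics.
    return torch.expm1(x)
```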

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -569,6 +569,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenLinalgNormOp>();
   target.addIllegalOp<AtenFminOp>();
   target.addIllegalOp<AtenFmaxOp>();
+  target.addIllegalOp<AtenSpecialExpm1Op>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
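
Marking `AtenSpecialExpm1Op` illegal here is what forces the decomposition to take effect: any instance that survives DecomposeComplexOps now fails the backend-contract conversion instead of silently reaching a backend that does not know the op.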

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 8 additions & 4 deletions
@@ -500,8 +500,6 @@
     "AdaptiveMaxPool1dStatic_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseExpm1IntModule_basic",
-    "ElementwiseExpm1Module_basic",
     "IsInfiniteModule_basic",
     "InterpolateDynamicModule_sizes_nearest",
     "IouOfModule_basic",
@@ -909,8 +907,6 @@
     "AtenItemIntOpModule_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseExpm1IntModule_basic",
-    "ElementwiseExpm1Module_basic",
     "InterpolateDynamicModule_sizes_nearest",
     "IouOfModule_basic",
     "IscloseStaticModuleTrue_basic",
@@ -1209,6 +1205,8 @@
     "ElementwiseRsqrtModule_basic",
     "ElementwiseSigmoidModule_basic",
     "ElementwiseSinModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtModule_basic",
     "ElementwiseTanIntModule_basic",
     "ElementwiseTanModule_basic",
@@ -2951,6 +2949,8 @@
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
     "ElementwiseCreateComplexModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
@@ -3662,6 +3662,8 @@
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",
@@ -4355,6 +4357,8 @@
     "ElementwiseSinIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtIntModule_basic",
     "ElementwiseSubScalarIntModule_basic",
     "ElementwiseTanIntModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
@@ -222,6 +222,9 @@ def aten〇exp2〡shape(self: List[int]) -> List[int]:
 def aten〇expm1〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇special_expm1〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇isfinite〡shape(self: List[int]) -> List[int]:
     return self
 
@@ -2717,6 +2720,11 @@ def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return _get_dtype_of_floating_point_op(self_dtype)
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇special_expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return _get_dtype_of_floating_point_op(self_dtype)
+
 def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
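
The dtype rule mirrors `aten.expm1`: a floating-point input keeps its dtype, while an integer input promotes via `_get_dtype_of_floating_point_op` to the default floating-point dtype. A quick eager-mode check of the behavior being modeled (assuming the default dtype is `torch.float32`; not part of the commit):

```python
import torch

# Floating inputs keep their dtype; integer inputs promote to float.
print(torch.special.expm1(torch.ones(2, dtype=torch.float64)).dtype)  # torch.float64
print(torch.special.expm1(torch.ones(2, dtype=torch.int32)).dtype)    # torch.float32
```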

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -452,6 +452,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
+    emit("aten::special_expm1 : (Tensor) -> (Tensor)")
     emit_with_mutating_variants(
         "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True
     )
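
Plain `emit` is used here rather than `emit_with_mutating_variants` because `aten::special_expm1` has no in-place variant to register; this single registry line is what produces the `Torch_AtenSpecialExpm1Op` TableGen definition in `GeneratedTorchOps.td` shown above.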

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 48 additions & 2 deletions
@@ -5207,7 +5207,7 @@ def __init__(self):
         ]
     )
     def forward(self, a):
-        return torch.special.expm1(a)
+        return torch.expm1(a)
 
 
 @register_test_case(module_factory=lambda: ElementwiseExpm1Module())
@@ -5230,7 +5230,7 @@ def __init__(self):
         ]
     )
     def forward(self, a):
-        return torch.special.expm1(a)
+        return torch.expm1(a)
 
 
 @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule())
@@ -5241,6 +5241,52 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseSpecialExpm1Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module())
+def ElementwiseSpecialExpm1Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
+class ElementwiseSpecialExpm1IntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.int32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1IntModule())
+def ElementwiseSpecialExpm1IntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseRad2DegModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
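
The two new test classes mirror the existing `ElementwiseExpm1*` pair but go through `torch.special.expm1`, exercising the new decomposition on float32 and int32 inputs. A standalone eager-mode sanity check of the identity the tests rely on (a sketch, not part of the suite):

```python
import torch

a = torch.rand(3, 4)
b = torch.randint(1, 10, (3, 4), dtype=torch.int32)
# torch.special.expm1 is an alias of torch.expm1, so results match exactly.
assert torch.equal(torch.special.expm1(a), torch.expm1(a))
assert torch.equal(torch.special.expm1(b), torch.expm1(b))
print("special.expm1 matches expm1 for float and int inputs")
```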
