
Commit 5eed562

xndcn authored and silvasean committed
add aten.sub.int/aten.mul.int lowering in TorchToStd
1 parent d8ba681 · commit 5eed562

File tree: 4 files changed, +82 -15 lines changed


e2e_testing/torchscript/scalar.py

Lines changed: 34 additions & 0 deletions
@@ -23,7 +23,41 @@ def __init__(self):
     def forward(self, lhs, rhs):
         return int(lhs)+int(rhs)
 
+class SubIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return int(lhs)-int(rhs)
+
+class MulIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return int(lhs)*int(rhs)
+
 
 @register_test_case(module_factory=lambda: AddIntModule())
 def AddIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
+
+@register_test_case(module_factory=lambda: SubIntModule())
+def SubIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
+
+@register_test_case(module_factory=lambda: MulIntModule())
+def MulIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))

lib/Conversion/TorchToStd/TorchToStd.cpp

Lines changed: 14 additions & 6 deletions
@@ -45,13 +45,15 @@ class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
 } // namespace
 
 namespace {
-class ConvertAtenAddIntOp : public OpConversionPattern<AtenAddIntOp> {
+template <typename AtenOp, typename BinOp>
+class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<AtenOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenAddIntOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenOp op,
+                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<arith::AddIOp>(op, adaptor.a(), adaptor.b());
+    rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.a(), adaptor.b());
     return success();
   }
 };
@@ -142,8 +144,14 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
     target.addIllegalOp<Torch::ConstantIntOp>();
     patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
                                                                context);
-    target.addIllegalOp<AtenAddIntOp>();
-    patterns.add<ConvertAtenAddIntOp>(typeConverter, context);
+    target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
+    patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
+        typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
+        typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
+        typeConverter, context);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
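
Note on the refactor above: each patterns.add<ConvertAtenBinaryOp<...>> line instantiates the new templated pattern for one (Torch op, arith op) pair. As a minimal sketch of what one instantiation expands to, ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp> behaves like the following hand-written pattern (the class name ConvertAtenSubIntOpExpanded is hypothetical; only the template exists in the source):

namespace {
// Illustrative expansion of ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>;
// it mirrors the old ConvertAtenAddIntOp with arith.addi swapped for arith.subi.
class ConvertAtenSubIntOpExpanded : public OpConversionPattern<AtenSubIntOp> {
public:
  using OpConversionPattern<AtenSubIntOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenSubIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The type converter has already turned both !torch.int operands into
    // i64 values, so the op lowers to a single arith.subi.
    rewriter.replaceOpWithNewOp<arith::SubIOp>(op, adaptor.a(), adaptor.b());
    return success();
  }
};
} // namespace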

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 11 additions & 8 deletions
@@ -471,8 +471,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitNumToTensorOp(numToTensorOp);
     } else if (isa<AtenAddcmulOp, AtenAddcdivOp>(op)) {
       return visitAtenAddCLikeOp(op, operands);
-    } else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
-      return visitBinaryScalarOp(scalarOp);
+    } else if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>(op)) {
+      return visitBinaryScalarOp(op, operands);
     } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
       return visitAtenNllLossForwardOp(nllForwardOp, operands);
     } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
@@ -590,7 +590,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult
   visitAtenEmbeddingOp(AtenEmbeddingOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
-  template <typename OpTy> ChangeResult visitBinaryScalarOp(OpTy op);
+  ChangeResult
+  visitBinaryScalarOp(Operation *op,
+                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 
   ChangeResult
   visitAtenBmmOp(AtenBmmOp op,
@@ -1276,12 +1278,13 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
-template <typename OpTy>
-ChangeResult TypeAnalyzer::visitBinaryScalarOp(OpTy op) {
+ChangeResult TypeAnalyzer::visitBinaryScalarOp(
+    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
-  knowledge.dtype = getPromotedResultType({op.a().getType(), op.b().getType()});
-  return getLatticeElement(op.getResult()).join(knowledge);
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  knowledge.dtype = getPromotedResultType(
+      {op->getOperand(0).getType(), op->getOperand(1).getType()});
+  return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
 // `torch.aten.tensor` get a tensor from a list. Each layer of the list

test/Conversion/TorchToStd/basic.mlir

Lines changed: 23 additions & 1 deletion
@@ -82,6 +82,28 @@ func @torch.constant.int() -> !torch.int {
 // CHECK: %[[INT:.*]] = torch_c.from_i64 %[[INT:.*]]
 // CHECK: return %[[INT:.*]] : !torch.int
 func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
-  %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
+  %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
+
+// CHECK-LABEL: func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK: %[[INT:.*]] = arith.subi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
+// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[INT:.*]]
+// CHECK: return %[[INT:.*]] : !torch.int
+func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+  %0 = torch.aten.sub.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
+
+// CHECK-LABEL: func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK: %[[INT:.*]] = arith.muli %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
+// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[INT:.*]]
+// CHECK: return %[[INT:.*]] : !torch.int
+func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+  %0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
   return %0 : !torch.int
 }

0 commit comments
