Commit 03fdf56

dan-garvey authored and cathyzhyi committed
add aten.add.int lowering in TorchToStd
1 parent 7616d28 · commit 03fdf56

File tree

5 files changed: +77, -1 lines changed

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
 from . import argmax
 from . import matmul
 from . import view
+from . import scalar
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

e2e_testing/torchscript/scalar.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+
+class AddIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return int(lhs) + int(rhs)
+
+
+@register_test_case(module_factory=lambda: AddIntModule())
+def AddIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
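
For context, a minimal sketch of what this test exercises, independent of the e2e harness (the module is redeclared here only for self-containment): in TorchScript, int(tensor) scripts to aten::Int.Tensor, and Python's + on the two resulting ints scripts to aten::add.int, the op this commit lowers.

import torch

# Hypothetical quick check outside the torch-mlir e2e harness: scripting the
# module shows the ops the e2e test exercises.
class AddIntModule(torch.nn.Module):
    def forward(self, lhs, rhs):
        return int(lhs) + int(rhs)

scripted = torch.jit.script(AddIntModule())
print(scripted.graph)                              # contains aten::Int and aten::add
print(scripted(torch.tensor(2), torch.tensor(3)))  # 5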

lib/Conversion/TorchToStd/TorchToStd.cpp

Lines changed: 15 additions & 0 deletions
@@ -44,6 +44,19 @@ class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenAddIntOp : public OpConversionPattern<AtenAddIntOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenAddIntOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<arith::AddIOp>(op, adaptor.a(), adaptor.b());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenNeIntOp : public OpConversionPattern<AtenNeIntOp> {
 public:
@@ -129,6 +142,8 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
     target.addIllegalOp<Torch::ConstantIntOp>();
     patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
                                                                context);
+    target.addIllegalOp<AtenAddIntOp>();
+    patterns.add<ConvertAtenAddIntOp>(typeConverter, context);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 21 additions & 1 deletion
@@ -232,7 +232,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
         AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
         AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
         AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
-        AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp>(op)) {
+        AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp>(
+            op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
@@ -426,6 +427,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitNumToTensorOp(numToTensorOp);
     } else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
       return visitAtenAddCLikeOp(op, operands);
+    } else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
+      return visitBinaryScalarOp(scalarOp);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -528,6 +531,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult
   visitAtenEmbeddingOp(AtenEmbeddingOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  template <typename OpTy> ChangeResult visitBinaryScalarOp(OpTy op);
+
   ChangeResult
   visitAtenBmmOp(AtenBmmOp op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -587,6 +592,13 @@ static ResultTypeState updateResultTypeState(ValueKnowledge *tensor,
   return new_state;
 }
 
+static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
+  ResultTypeState state = {};
+  for (const Type &scalarType : scalarTypes)
+    state = updateResultTypeState(scalarType, state);
+  return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
+}
+
 // Returns most generic type Type() if the tensor dtype is unknown.
 static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
   if (!tensor->dtype)
@@ -1086,6 +1098,14 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+template <typename OpTy>
+ChangeResult TypeAnalyzer::visitBinaryScalarOp(OpTy op) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  knowledge.dtype = getPromotedResultType({op.a().getType(), op.b().getType()});
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 // `torch.aten.tensor` get a tensor from a list. Each layer of the list
 // corresponds to one dim of the tensor.
 ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
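
The new getPromotedResultType overload folds scalar type promotion across the operand types, and visitBinaryScalarOp uses it to refine the result dtype of aten.add.int. A rough Python analogue, using torch.promote_types as a stand-in for the updateResultTypeState/result_type machinery (an approximation, not the exact C++ semantics):

import torch

# Rough analogue of getPromotedResultType(ArrayRef<Type>): fold pairwise
# promotion over the operand dtypes. For aten.add.int both operands are
# int64, so the refined result dtype stays int64.
def promoted_result_type(dtypes):
    result = dtypes[0]
    for dtype in dtypes[1:]:
        result = torch.promote_types(result, dtype)
    return result

assert promoted_result_type([torch.int64, torch.int64]) == torch.int64
assert promoted_result_type([torch.int64, torch.float64]) == torch.float64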

test/Conversion/TorchToStd/basic.mlir

Lines changed: 11 additions & 0 deletions
@@ -74,3 +74,14 @@ func @torch.constant.int() -> !torch.int {
   %int1 = torch.constant.int 1
   return %int1 : !torch.int
 }
+
+// CHECK-LABEL: func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+// CHECK:         %[[LHS_I64:.*]] = torch_c.to_i64 %arg0
+// CHECK:         %[[RHS_I64:.*]] = torch_c.to_i64 %arg1
+// CHECK:         %[[SUM_I64:.*]] = arith.addi %[[LHS_I64]], %[[RHS_I64]] : i64
+// CHECK:         %[[SUM:.*]] = torch_c.from_i64 %[[SUM_I64]]
+// CHECK:         return %[[SUM]] : !torch.int
+func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+  %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
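
The RUN line for this test file sits outside the hunk shown here. By torch-mlir convention, such FileCheck tests pipe the file through torch-mlir-opt with the TorchToStd conversion pass (registered by ConvertTorchToStdBase above; the exact flag, presumably convert-torch-to-std, is an assumption since the RUN line is not part of the diff) and verify the output with FileCheck.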
