Skip to content

Commit 7cf7b91

Browse files
committed
[MLIR][TORCH] Fix tensor literal int elem type to be signless
The element type of a tensor literal should be signless when converted to builtin tensor types.
1 parent d6b6c02 commit 7cf7b91

File tree

3 files changed

+86
-7
lines changed

3 files changed

+86
-7
lines changed

e2e_testing/torchscript/basic.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ def forward(self, x):
602602
@register_test_case(module_factory=lambda: ContiguousModule())
603603
def ContiguousModule_basic(module, tu: TestUtils):
604604
module.forward(tu.rand(3, 1))
605-
605+
606606
class TensorToInt(torch.nn.Module):
607607
def __init__(self):
608608
super().__init__()
@@ -699,7 +699,7 @@ def __init__(self):
699699

700700
@export
701701
@annotate_args([
702-
None,
702+
None,
703703
([-1, -1], torch.float32, True),
704704
([-1, -1], torch.float32, True),
705705
([-1, -1], torch.float32, True),
@@ -718,7 +718,7 @@ def __init__(self):
718718

719719
@export
720720
@annotate_args([
721-
None,
721+
None,
722722
([-1, -1], torch.float32, True),
723723
([-1, -1], torch.float32, True),
724724
([-1, -1], torch.float32, True),
@@ -739,7 +739,7 @@ def __init__(self):
739739

740740
@export
741741
@annotate_args([
742-
None,
742+
None,
743743
])
744744

745745
def forward(self):
@@ -756,7 +756,7 @@ def __init__(self):
756756

757757
@export
758758
@annotate_args([
759-
None,
759+
None,
760760
])
761761

762762
def forward(self):
@@ -963,3 +963,38 @@ def forward(self, lhs):
963963
def TModuleRank0_basic(module, tu: TestUtils):
964964
module.forward(torch.tensor(7, dtype=torch.float32))
965965

966+
class TensorLiteralModule(torch.nn.Module):
967+
def __init__(self):
968+
super().__init__()
969+
torch.manual_seed(0)
970+
self.t = torch.randint(-5, 5, (2, 3))
971+
972+
@export
973+
@annotate_args([
974+
None,
975+
])
976+
def forward(self):
977+
return torch.add(self.t, self.t)
978+
979+
@register_test_case(module_factory=lambda: TensorLiteralModule())
980+
def TensorLiteralModule_basic(module, tu: TestUtils):
981+
module.forward()
982+
983+
984+
class TensorOpaqueLiteralModule(torch.nn.Module):
985+
def __init__(self):
986+
super().__init__()
987+
torch.manual_seed(0)
988+
self.t = torch.randint(-5, 5, (256, 1024))
989+
990+
@export
991+
@annotate_args([
992+
None,
993+
])
994+
def forward(self):
995+
return torch.add(self.t, self.t)
996+
997+
@register_test_case(module_factory=lambda: TensorOpaqueLiteralModule())
998+
def TensorOpaqueLiteralModule_basic(module, tu: TestUtils):
999+
module.forward()
1000+

lib/Conversion/TorchToStd/TorchToStd.cpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,49 @@ class ConvertAtenGtIntOp : public OpConversionPattern<AtenGtIntOp> {
8787
};
8888
} // namespace
8989

90+
// Tensors with integer types need to be converted to signless integer
91+
// element type. All tensors with element types other than integer can reuse
92+
// existing elements attribute.
93+
namespace {
94+
class ConvertTorchTensorLiteralOp
95+
: public OpConversionPattern<ValueTensorLiteralOp> {
96+
public:
97+
using OpConversionPattern<ValueTensorLiteralOp>::OpConversionPattern;
98+
using OpAdaptor = ValueTensorLiteralOp::Adaptor;
99+
LogicalResult
100+
matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
101+
ConversionPatternRewriter &rewriter) const override {
102+
MLIRContext *context = op->getContext();
103+
if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
104+
Type elemTy = op.valueAttr().getElementType();
105+
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
106+
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
107+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
108+
op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
109+
return APInt(bitWidth, v.getSExtValue());
110+
}));
111+
return success();
112+
}
113+
if (auto elements = op.valueAttr().dyn_cast<OpaqueElementsAttr>()) {
114+
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
115+
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
116+
Type builtinTensorElemTy =
117+
IntegerType::get(context, intType.getIntOrFloatBitWidth());
118+
auto shapedType =
119+
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
120+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
121+
op, OpaqueElementsAttr::get(elements.getDialect(), shapedType,
122+
elements.getValue()));
123+
return success();
124+
}
125+
}
126+
}
127+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.valueAttr());
128+
return success();
129+
}
130+
};
131+
} // namespace
132+
90133
namespace {
91134
template <typename OpTy>
92135
class ConvertTorchConstantOp : public OpConversionPattern<OpTy> {
@@ -133,8 +176,8 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
133176
target.addIllegalOp<AtenGtIntOp>();
134177
patterns.add<ConvertAtenGtIntOp>(typeConverter, context);
135178
target.addIllegalOp<ValueTensorLiteralOp>();
136-
patterns.add<ConvertTorchConstantOp<ValueTensorLiteralOp>>(typeConverter,
137-
context);
179+
patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);
180+
138181
target.addIllegalOp<ConstantBoolOp>();
139182
patterns.add<ConvertTorchConstantOp<ConstantBoolOp>>(typeConverter,
140183
context);

lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class VerifyLinalgOnTensorsBackendContractPass
6464
// Tensor operations should go through linalg and the tensor dialect.
6565
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
6666
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
67+
target.addDynamicallyLegalDialect<AffineDialect>(opHasLegalTypes);
6768

6869
// AssertOp is used to terminate the program for error guards.
6970
target.addLegalOp<AssertOp>();

0 commit comments

Comments (0)