Commit 141202b

[TorchToLinalg] Fix integer type handling for aten.mm (#2615)
Despite aten.mm requiring that the input and output types match, we still opt to maintain signedness semantics in case later passes try to do any sort of integer type narrowing.
1 parent c011570 commit 141202b
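The converted builtin tensor types carry signless integers, so ui32 and si32 become indistinguishable after type conversion; that is why the pattern in the Linear.cpp diff below inspects the dtype of the original !torch.vtensor operands. A minimal sketch of that dispatch test, lifted out of the pattern for clarity (the helper name isUnsignedTorchDtype is ours, not part of the patch):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Sketch only: true when the Torch-level element type is an unsigned integer,
// in which case the pattern emits linalg.matmul_unsigned instead of
// linalg.matmul. The converted builtin tensor type is signless, so signedness
// has to be read off the original !torch.vtensor element type.
static bool isUnsignedTorchDtype(Type torchDtype) {
  auto intType = dyn_cast<IntegerType>(torchDtype);
  return intType && intType.isUnsigned();
}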

File tree

3 files changed: +78 -13 lines
lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 30 additions & 12 deletions
@@ -51,12 +51,24 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
     // The compiler cannot crash even if the user wrote an erroneous program!
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
-        rhs.getType().cast<RankedTensorType>().getRank() != 2) {
+
+    RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
+    RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
+
+    if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
       return rewriter.notifyMatchFailure(
           op, "expected both operands to aten.mm to be rank 2");
     }

+    ValueTensorType lhsTorchType =
+        op.getSelf().getType().cast<ValueTensorType>();
+    ValueTensorType rhsTorchType =
+        op.getMat2().getType().cast<ValueTensorType>();
+    if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: aten.mm with different input element types");
+    }
+
     Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
     Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);

@@ -73,16 +85,22 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {

     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<TensorType>().getElementType();
-    Value initTensor = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<OpFoldResult>{lhsDim0, rhsDim1}, elementType);
-    Value c0 = rewriter.create<arith::ConstantOp>(
-        loc, FloatAttr::get(elementType, 0.0));
-    Value zeroFill =
-        rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
-    Value matmul = rewriter
-                       .create<linalg::MatmulOp>(loc, zeroFill.getType(),
-                                                 ValueRange{lhs, rhs}, zeroFill)
-                       .getResult(0);
+    Value zeroFill = createZeroInitTensor(
+        rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
+
+    Value matmul;
+    auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
+    if (intType && intType.isUnsigned()) {
+      matmul = rewriter
+                   .create<linalg::MatmulUnsignedOp>(
+                       loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
+                   .getResult(0);
+    } else {
+      matmul = rewriter
+                   .create<linalg::MatmulOp>(loc, zeroFill.getType(),
+                                             ValueRange{lhs, rhs}, zeroFill)
+                   .getResult(0);
+    }
     // When constructed with just dynamic sizes, EmptyOp will have a result
     // type which has all `?`'s for dimensions, which might not be the result
     // type of `op`. The constraints on later linalg ops means that the result

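The accumulator initialization was the part that actually misbehaved for integer inputs: the removed code materialized the zero with FloatAttr::get(elementType, 0.0) regardless of the element type, which is only meaningful for floating-point element types. Switching to createZeroInitTensor avoids that. Roughly, such a helper follows the usual empty-then-fill pattern; the sketch below is written under that assumption, and its name and signature are ours rather than the torch-mlir utility's:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch of an empty-then-fill zero-init helper. The important detail is
// getZeroAttr, which yields a zero of the element type's own kind (integer or
// float), unlike the removed FloatAttr::get(elementType, 0.0).
static Value zeroInitTensorSketch(OpBuilder &b, Location loc,
                                  ArrayRef<OpFoldResult> sizes,
                                  Type elemType) {
  // Allocate an empty tensor with the requested (possibly dynamic) sizes.
  Value empty = b.create<tensor::EmptyOp>(loc, sizes, elemType);
  // Build a zero constant matching the element type.
  Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elemType));
  // Fill the tensor with that zero and return the filled value.
  return b.create<linalg::FillOp>(loc, zero, empty).getResult(0);
}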
projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 37 additions & 1 deletion
@@ -225,4 +225,40 @@ def forward(self, m, v):

 @register_test_case(module_factory=lambda: Mv())
 def Mv_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 2), tu.rand(2))
+    module.forward(tu.rand(2, 2), tu.rand(2))
+
+# ==============================================================================
+
+class AtenMmFloatTypes(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmFloatTypes())
+def AtenMmFloatTypes_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 8), tu.rand(8, 8))
+
+# ==============================================================================
+
+class AtenMmIntTypes(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmIntTypes())
+def AtenMmIntTypes_basic(module, tu: TestUtils):
+    module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))

test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 11 additions & 0 deletions
@@ -40,6 +40,17 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !

 // -----

+// CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned(
+// CHECK: linalg.matmul_unsigned
+func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
+  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32>
+  return %0 : !torch.vtensor<[?,2],ui32>
+}
+
+// -----
+
 // If the operands are missing dtype, we cannot lower it.
 func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   // expected-error@+1 {{failed to legalize}}
