@@ -51,12 +51,24 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
5151 // The compiler cannot crash even if the user wrote an erroneous program!
5252 if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
5353 return failure ();
54- if (lhs.getType ().cast <RankedTensorType>().getRank () != 2 ||
55- rhs.getType ().cast <RankedTensorType>().getRank () != 2 ) {
54+
55+ RankedTensorType lhsType = lhs.getType ().cast <RankedTensorType>();
56+ RankedTensorType rhsType = rhs.getType ().cast <RankedTensorType>();
57+
58+ if (lhsType.getRank () != 2 || rhsType.getRank () != 2 ) {
5659 return rewriter.notifyMatchFailure (
5760 op, " expected both operands to aten.mm to be rank 2" );
5861 }
5962
63+ ValueTensorType lhsTorchType =
64+ op.getSelf ().getType ().cast <ValueTensorType>();
65+ ValueTensorType rhsTorchType =
66+ op.getMat2 ().getType ().cast <ValueTensorType>();
67+ if (lhsTorchType.getDtype () != rhsTorchType.getDtype ()) {
68+ return rewriter.notifyMatchFailure (
69+ op, " unsupported: aten.mm with different input element types" );
70+ }
71+
6072 Value lhsDim0 = rewriter.create <tensor::DimOp>(loc, lhs, 0 );
6173 Value rhsDim1 = rewriter.create <tensor::DimOp>(loc, rhs, 1 );
6274
@@ -73,16 +85,22 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
7385
7486 Type newResultType = getTypeConverter ()->convertType (op.getType ());
7587 Type elementType = newResultType.cast <TensorType>().getElementType ();
76- Value initTensor = rewriter.create <tensor::EmptyOp>(
77- loc, ArrayRef<OpFoldResult>{lhsDim0, rhsDim1}, elementType);
78- Value c0 = rewriter.create <arith::ConstantOp>(
79- loc, FloatAttr::get (elementType, 0.0 ));
80- Value zeroFill =
81- rewriter.create <linalg::FillOp>(loc, c0, initTensor).getResult (0 );
82- Value matmul = rewriter
83- .create <linalg::MatmulOp>(loc, zeroFill.getType (),
84- ValueRange{lhs, rhs}, zeroFill)
85- .getResult (0 );
88+ Value zeroFill = createZeroInitTensor (
89+ rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
90+
91+ Value matmul;
92+ auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype ());
93+ if (intType && intType.isUnsigned ()) {
94+ matmul = rewriter
95+ .create <linalg::MatmulUnsignedOp>(
96+ loc, zeroFill.getType (), ValueRange{lhs, rhs}, zeroFill)
97+ .getResult (0 );
98+ } else {
99+ matmul = rewriter
100+ .create <linalg::MatmulOp>(loc, zeroFill.getType (),
101+ ValueRange{lhs, rhs}, zeroFill)
102+ .getResult (0 );
103+ }
86104 // When constructed with just dynamic sizes, EmptyOp will have a result
87105 // type which has all `?`'s for dimensions, which might not be the result
88106 // type of `op`. The constraints on later linalg ops means that the result
0 commit comments