Skip to content

Commit 8f97d89

Browse files
[MLIR][AArch64] Lower vector.contract with mixed signed/unsigned arguments to Neon FEAT_I8MM
1 parent 66d6964 commit 8f97d89

File tree

2 files changed

+203
-36
lines changed

2 files changed

+203
-36
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 136 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,81 @@ static Type matchContainerType(Type element, Type container) {
3737
return element;
3838
}
3939

40+
// Get the operand of a `vector.contract`. This function is intended to abstract
41+
// away from the particular way a value is extended before feeding it into the
42+
// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
43+
// (for implicit sign-extension see `vector.contract` documentation).
44+
//
45+
// The template parameter `Op` indicates the extension operation (explicit or
46+
// implicit) for which we are checking.
47+
//
48+
// Return success only for extensions from `iN` (N <= 8) to `i32`.
49+
template <typename Op>
50+
std::optional<Value> getExtOperand(Value v) {
51+
52+
static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
53+
"Must be instantiated with either sign- or zero- extension op");
54+
55+
// If the operand is not defined by an explicit extend operation of the
56+
// accepted operation type allow for an implicit sign-extension.
57+
auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
58+
if (!extOp) {
59+
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
60+
auto eltTy = cast<VectorType>(v.getType()).getElementType();
61+
if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
62+
return {};
63+
return v;
64+
}
65+
return {};
66+
}
67+
68+
// If the operand is defined by an explicit extend operation of the accepted
69+
// operation type, check it's extended from `iN` (N <= 8) to `i32`.
70+
auto inOp = extOp.getIn();
71+
auto inTy = dyn_cast<VectorType>(inOp.getType());
72+
if (!inTy)
73+
return {};
74+
auto inEltTy = inTy.getElementType();
75+
if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
76+
return {};
77+
78+
auto outTy = dyn_cast<VectorType>(extOp.getType());
79+
if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
80+
return {};
81+
82+
return inOp;
83+
}
84+
85+
// Designate the operation (resp. instruction) used to do sub-tile matrix
// multiplications.
enum class MMLA {
  Signed,      // smmla
  Unsigned,    // ummla
  Mixed,       // usmmla
  MixedSwapped // usmmla with LHS and RHS swapped
};
93+
94+
// Create the matrix mulitply and accumulate operation according to `op`.
95+
Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
96+
mlir::Type accType, Value acc, Value lhs, Value rhs) {
97+
switch (op) {
98+
case MMLA::Signed:
99+
return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
100+
rhs);
101+
case MMLA::Unsigned:
102+
return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
103+
rhs);
104+
case MMLA::Mixed:
105+
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
106+
rhs);
107+
case MMLA::MixedSwapped:
108+
// The accumulator comes transposed and the result will be transposed
109+
// later, so all we have to do here is swap the operands.
110+
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
111+
lhs);
112+
}
113+
}
114+
40115
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
41116
/// any vector.contract into multiple smmla instructions with unrolling so long
42117
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -88,39 +163,64 @@ class LowerContractionToSMMLAPattern
88163
return failure();
89164
}
90165

91-
// Check two extsi inputs Rhs Lhs for contract.
92-
arith::ExtSIOp origLhsExtOp =
93-
dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
94-
arith::ExtSIOp origRhsExtOp =
95-
dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
96-
if (!origLhsExtOp || !origRhsExtOp) {
166+
// Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
167+
// values before the extension. All four signed/unsigned combinations for
168+
// input operands are supported, but they are lowered to different
169+
// operations. Determine which is the appropriate operation to lower to.
170+
MMLA mmlaOp = MMLA::Signed;
171+
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
172+
if (!maybeLhs) {
173+
mmlaOp = MMLA::Unsigned;
174+
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
175+
}
176+
if (!maybeLhs)
97177
return failure();
178+
179+
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
180+
if (maybeRhs) {
181+
if (mmlaOp == MMLA::Unsigned)
182+
mmlaOp = MMLA::Mixed;
183+
} else {
184+
if (mmlaOp == MMLA::Signed)
185+
mmlaOp = MMLA::MixedSwapped;
186+
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
98187
}
188+
if (!maybeRhs)
189+
return failure();
190+
191+
Value origLhs = *maybeLhs;
192+
Value origRhs = *maybeRhs;
99193

100194
// Match any iX to i32 for X<8 then turn into an i8 output. Feed into
101195
// following neon instruction. Check inputs for extsi are <=i8
102-
Value extsiLhs;
103-
Value extsiRhs;
104-
if (auto lhsExtInType =
105-
dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
196+
Value extLhs;
197+
Value extRhs;
198+
if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
106199
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
107200
Type targetLhsExtTy =
108201
matchContainerType(rewriter.getI8Type(), lhsExtInType);
109-
extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
110-
origLhsExtOp.getIn());
202+
if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
203+
extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
204+
origLhs);
205+
else
206+
extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
207+
origLhs);
111208
}
112209
}
113-
if (auto rhsExtInType =
114-
dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
210+
if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
115211
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
116212
Type targetRhsExtTy =
117213
matchContainerType(rewriter.getI8Type(), rhsExtInType);
118-
extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
119-
origRhsExtOp.getIn());
214+
if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
215+
extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
216+
origRhs);
217+
else
218+
extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
219+
origRhs);
120220
}
121221
}
122222

123-
if (!extsiLhs || !extsiRhs) {
223+
if (!extLhs || !extRhs) {
124224
return failure();
125225
}
126226

@@ -155,11 +255,11 @@ class LowerContractionToSMMLAPattern
155255
AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
156256
SmallVector<int64_t> lhsOffsets =
157257
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
158-
Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
258+
Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
159259
AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
160260
SmallVector<int64_t> rhsOffsets =
161261
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
162-
Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
262+
Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
163263
AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
164264
SmallVector<int64_t> accOffsets =
165265
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
@@ -191,6 +291,13 @@ class LowerContractionToSMMLAPattern
191291
tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
192292
}
193293

294+
// Transpose ACC if doing signed by unsigned multiplication, because we're
295+
// using the instruction for unsigned by signed multiplication with
296+
// reversed operands.
297+
if (mmlaOp == MMLA::MixedSwapped)
298+
tiledAcc = rewriter.create<vector::TransposeOp>(
299+
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
300+
194301
// Collapse tiled operands to 1D vectors required by smmla intrinsic
195302
auto collapsedInputType =
196303
VectorType::get(inputExpandedType.getNumElements(), inputElementType);
@@ -211,15 +318,21 @@ class LowerContractionToSMMLAPattern
211318
}
212319

213320
// Insert contract op
214-
kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
215-
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
216-
collapsedRhs);
321+
kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
322+
collapsedRes, collapsedLhs, collapsedRhs);
217323

218324
// Reshape output back to 2D
219325
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
220326
kAcc.getLoc(), tiledAcc.getType(), kAcc);
221327

222-
// With vecmat, only one row of tiled ACC can be inserted into file result
328+
// Because of the reversed operands the result is obtained transposed.
329+
// Transpose it back,
330+
if (mmlaOp == MMLA::MixedSwapped)
331+
tiledRes = rewriter.create<vector::TransposeOp>(
332+
loc, tiledRes, ArrayRef<int64_t>({1, 0}));
333+
334+
// With vecmat, only one row of tiled ACC can be inserted into the final
335+
// result
223336
if (isVecmat) {
224337
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
225338
}

mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,28 @@ func.func @vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4
1717

1818
// -----
1919

20-
// CHECK-LABEL: vector_arm_neon_same_types
21-
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32>
22-
// CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
23-
// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8>
24-
// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
25-
// CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32>
26-
// CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32>
27-
func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
20+
// i8 operands fed directly into the contract rely on vector.contract's
// implicit sign-extension, so this lowers to smmla.
// CHECK-LABEL: vector_arm_neon_implicit_extsi
// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
// CHECK: %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
func.func @vector_arm_neon_implicit_extsi(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi8>, vector<2x8xi8> into vector<2x2xi32>
  return %res : vector<2x2xi32>
}
31+
32+
// -----
33+
34+
// CHECK-LABEL: vector_arm_neon_signed_signed
35+
// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
36+
// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
37+
// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
38+
// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
39+
// CHECK: %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
40+
// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
41+
func.func @vector_arm_neon_signed_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
2842
%lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
2943
%rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
3044
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
@@ -33,11 +47,51 @@ func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>
3347

3448
// -----
3549

36-
// CHECK-LABEL: vector_arm_neon_without_extsi
37-
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32>
38-
// CHECK-DAG: %[[D0:.*]] = vector.contract
39-
func.func @vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
40-
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
50+
// Unsigned LHS by signed RHS lowers directly to usmmla.
// CHECK-LABEL: vector_arm_neon_unsigned_signed
// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
// CHECK: %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
func.func @vector_arm_neon_unsigned_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
  %lhs_extui = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
  %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extui, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
  return %res : vector<2x2xi32>
}
63+
64+
// -----
65+
66+
// Unsigned by unsigned lowers to ummla.
// CHECK-LABEL: vector_arm_neon_unsigned_unsigned
// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
// CHECK: %[[M:.+]] = arm_neon.intr.ummla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
func.func @vector_arm_neon_unsigned_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
  %lhs_extui = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
  %rhs_extui = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extui, %rhs_extui, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
  return %res : vector<2x2xi32>
}
79+
80+
// -----
81+
82+
// Signed LHS by unsigned RHS: there is no "smumla", so the lowering uses
// usmmla with swapped operands, transposing the accumulator on the way in
// and the result on the way out.
// CHECK-LABEL: vector_arm_neon_signed_unsigned
// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
// CHECK: %[[ACC_T:.+]] = vector.transpose %[[ACC]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC_T]] : vector<2x2xi32> to vector<4xi32>
// CHECK: %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[R]], %[[L]] : vector<16xi8> to vector<4xi32>
// CHECK: %[[OUT_T:.+]] = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
// CHECK: %{{.+}} = vector.transpose %[[OUT_T]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
func.func @vector_arm_neon_signed_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
  %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
  %rhs_extui = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extui, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
  return %res : vector<2x2xi32>
}
4397

0 commit comments

Comments
 (0)