@@ -37,6 +37,81 @@ static Type matchContainerType(Type element, Type container) {
3737 return element;
3838}
3939
40+ // Get the operand of a `vector.contract`. This function is intended to abstract
41+ // away from the particular way a value is extended before feeding it into the
42+ // `vector.contract` - via zero-extend or an explicit or implicit sign-extend
43+ // (for implicit sign-extension see `vector.contract` documentation).
44+ //
45+ // The template parameter `Op` indicates the extension operation (explicit or
46+ // implicit) for which we are checking.
47+ //
48+ // Return success only for extensions from `iN` (N <= 8) to `i32`.
49+ template <typename Op>
50+ std::optional<Value> getExtOperand (Value v) {
51+
52+ static_assert (llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
53+ " Must be instantiated with either sign- or zero- extension op" );
54+
55+ // If the operand is not defined by an explicit extend operation of the
56+ // accepted operation type allow for an implicit sign-extension.
57+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp ());
58+ if (!extOp) {
59+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
60+ auto eltTy = cast<VectorType>(v.getType ()).getElementType ();
61+ if (!eltTy.isSignlessInteger () || eltTy.getIntOrFloatBitWidth () > 8 )
62+ return {};
63+ return v;
64+ }
65+ return {};
66+ }
67+
68+ // If the operand is defined by an explicit extend operation of the accepted
69+ // operation type, check it's extended from `iN` (N <= 8) to `i32`.
70+ auto inOp = extOp.getIn ();
71+ auto inTy = dyn_cast<VectorType>(inOp.getType ());
72+ if (!inTy)
73+ return {};
74+ auto inEltTy = inTy.getElementType ();
75+ if (!inEltTy.isSignlessInteger () || inEltTy.getIntOrFloatBitWidth () > 8 )
76+ return {};
77+
78+ auto outTy = dyn_cast<VectorType>(extOp.getType ());
79+ if (!(outTy && outTy.getElementType ().isSignlessInteger (32 )))
80+ return {};
81+
82+ return inOp;
83+ }
84+
85+ // Designate the operation (resp. instruction) used to do sub-tile matrix
86+ // multiplications.
87+ enum class MMLA {
88+ Signed, // smmla
89+ Unsigned, // ummla
90+ Mixed, // usmmla
91+ MixedSwapped // usmmla with LHS and RHS swapped
92+ };
93+
94+ // Create the matrix mulitply and accumulate operation according to `op`.
95+ Value createMMLA (PatternRewriter &rewriter, MMLA op, Location loc,
96+ mlir::Type accType, Value acc, Value lhs, Value rhs) {
97+ switch (op) {
98+ case MMLA::Signed:
99+ return rewriter.createOrFold <arm_neon::SmmlaOp>(loc, accType, acc, lhs,
100+ rhs);
101+ case MMLA::Unsigned:
102+ return rewriter.createOrFold <arm_neon::UmmlaOp>(loc, accType, acc, lhs,
103+ rhs);
104+ case MMLA::Mixed:
105+ return rewriter.createOrFold <arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
106+ rhs);
107+ case MMLA::MixedSwapped:
108+ // The accumulator comes transposed and the result will be transposed
109+ // later, so all we have to do here is swap the operands.
110+ return rewriter.createOrFold <arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
111+ lhs);
112+ }
113+ }
114+
40115// / Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
41116// / any vector.contract into multiple smmla instructions with unrolling so long
42117// / as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -88,39 +163,64 @@ class LowerContractionToSMMLAPattern
88163 return failure ();
89164 }
90165
91- // Check two extsi inputs Rhs Lhs for contract.
92- arith::ExtSIOp origLhsExtOp =
93- dyn_cast_or_null<arith::ExtSIOp>(op.getLhs ().getDefiningOp ());
94- arith::ExtSIOp origRhsExtOp =
95- dyn_cast_or_null<arith::ExtSIOp>(op.getRhs ().getDefiningOp ());
96- if (!origLhsExtOp || !origRhsExtOp) {
166+ // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
167+ // values before the extension. All four signed/unsigned combinations for
168+ // input operands are supported, but they are lowered to different
169+ // operations. Determine which is the appropriate operation to lower to.
170+ MMLA mmlaOp = MMLA::Signed;
171+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs ());
172+ if (!maybeLhs) {
173+ mmlaOp = MMLA::Unsigned;
174+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs ());
175+ }
176+ if (!maybeLhs)
97177 return failure ();
178+
179+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs ());
180+ if (maybeRhs) {
181+ if (mmlaOp == MMLA::Unsigned)
182+ mmlaOp = MMLA::Mixed;
183+ } else {
184+ if (mmlaOp == MMLA::Signed)
185+ mmlaOp = MMLA::MixedSwapped;
186+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs ());
98187 }
188+ if (!maybeRhs)
189+ return failure ();
190+
191+ Value origLhs = *maybeLhs;
192+ Value origRhs = *maybeRhs;
99193
100194 // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
101195 // following neon instruction. Check inputs for extsi are <=i8
102- Value extsiLhs;
103- Value extsiRhs;
104- if (auto lhsExtInType =
105- dyn_cast<mlir::VectorType>(origLhsExtOp.getIn ().getType ())) {
196+ Value extLhs;
197+ Value extRhs;
198+ if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType ())) {
106199 if (lhsExtInType.getElementTypeBitWidth () <= 8 ) {
107200 Type targetLhsExtTy =
108201 matchContainerType (rewriter.getI8Type (), lhsExtInType);
109- extsiLhs = rewriter.createOrFold <arith::ExtSIOp>(loc, targetLhsExtTy,
110- origLhsExtOp.getIn ());
202+ if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
203+ extLhs = rewriter.createOrFold <arith::ExtSIOp>(loc, targetLhsExtTy,
204+ origLhs);
205+ else
206+ extLhs = rewriter.createOrFold <arith::ExtUIOp>(loc, targetLhsExtTy,
207+ origLhs);
111208 }
112209 }
113- if (auto rhsExtInType =
114- dyn_cast<mlir::VectorType>(origRhsExtOp.getIn ().getType ())) {
210+ if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType ())) {
115211 if (rhsExtInType.getElementTypeBitWidth () <= 8 ) {
116212 Type targetRhsExtTy =
117213 matchContainerType (rewriter.getI8Type (), rhsExtInType);
118- extsiRhs = rewriter.createOrFold <arith::ExtSIOp>(loc, targetRhsExtTy,
119- origRhsExtOp.getIn ());
214+ if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
215+ extRhs = rewriter.createOrFold <arith::ExtUIOp>(loc, targetRhsExtTy,
216+ origRhs);
217+ else
218+ extRhs = rewriter.createOrFold <arith::ExtSIOp>(loc, targetRhsExtTy,
219+ origRhs);
120220 }
121221 }
122222
123- if (!extsiLhs || !extsiRhs ) {
223+ if (!extLhs || !extRhs ) {
124224 return failure ();
125225 }
126226
@@ -155,11 +255,11 @@ class LowerContractionToSMMLAPattern
155255 AffineMap lhsPermutationMap = op.getIndexingMapsArray ()[0 ];
156256 SmallVector<int64_t > lhsOffsets =
157257 applyPermutationMap (lhsPermutationMap, ArrayRef<int64_t >(offsets));
158- Value tiledLhs = extractOperand (extsiLhs , lhsPermutationMap, lhsOffsets);
258+ Value tiledLhs = extractOperand (extLhs , lhsPermutationMap, lhsOffsets);
159259 AffineMap rhsPermutationMap = op.getIndexingMapsArray ()[1 ];
160260 SmallVector<int64_t > rhsOffsets =
161261 applyPermutationMap (rhsPermutationMap, ArrayRef<int64_t >(offsets));
162- Value tiledRhs = extractOperand (extsiRhs , rhsPermutationMap, rhsOffsets);
262+ Value tiledRhs = extractOperand (extRhs , rhsPermutationMap, rhsOffsets);
163263 AffineMap accPermutationMap = op.getIndexingMapsArray ()[2 ];
164264 SmallVector<int64_t > accOffsets =
165265 applyPermutationMap (accPermutationMap, ArrayRef<int64_t >(offsets));
@@ -191,6 +291,13 @@ class LowerContractionToSMMLAPattern
191291 tiledAcc = expandForSMMLA (tiledAcc, outputExpandedType);
192292 }
193293
294+ // Transpose ACC if doing signed by unsigned multiplication, because we're
295+ // using the instruction for unsigned by signed multiplication with
296+ // reversed operands.
297+ if (mmlaOp == MMLA::MixedSwapped)
298+ tiledAcc = rewriter.create <vector::TransposeOp>(
299+ loc, tiledAcc, ArrayRef<int64_t >({1 , 0 }));
300+
194301 // Collapse tiled operands to 1D vectors required by smmla intrinsic
195302 auto collapsedInputType =
196303 VectorType::get (inputExpandedType.getNumElements (), inputElementType);
@@ -211,15 +318,21 @@ class LowerContractionToSMMLAPattern
211318 }
212319
213320 // Insert contract op
214- kAcc = rewriter.createOrFold <arm_neon::SmmlaOp>(
215- op.getLoc (), collapsedRes.getType (), collapsedRes, collapsedLhs,
216- collapsedRhs);
321+ kAcc = createMMLA (rewriter, mmlaOp, op.getLoc (), collapsedRes.getType (),
322+ collapsedRes, collapsedLhs, collapsedRhs);
217323
218324 // Reshape output back to 2D
219325 Value tiledRes = rewriter.createOrFold <vector::ShapeCastOp>(
220326 kAcc .getLoc (), tiledAcc.getType (), kAcc );
221327
222- // With vecmat, only one row of tiled ACC can be inserted into file result
328+ // Because of the reversed operands the result is obtained transposed.
329+ // Transpose it back,
330+ if (mmlaOp == MMLA::MixedSwapped)
331+ tiledRes = rewriter.create <vector::TransposeOp>(
332+ loc, tiledRes, ArrayRef<int64_t >({1 , 0 }));
333+
334+ // With vecmat, only one row of tiled ACC can be inserted into the final
335+ // result
223336 if (isVecmat) {
224337 tiledRes = rewriter.createOrFold <vector::ExtractOp>(loc, tiledRes, 0 );
225338 }
0 commit comments