@@ -1,15 +1,20 @@
-//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA --*- C++ -*-===//
+//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements lowering patterns from vector.contract to
-// arm_neon.intr.smmla
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM extension.
 //
-//===---
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToSVEI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -37,12 +42,87 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }

+// Get the operand of a `vector.contract`. This function is intended to
+// abstract away the particular way a value is extended before feeding it
+// into the `vector.contract` - via zero-extend or an explicit or implicit
+// sign-extend (for implicit sign-extension see `vector.contract`
+// documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
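+//
+// For illustration (example IR, not from this patch), both of the following
+// LHS forms would be accepted by the signed variant, assuming an `i32`
+// accumulator:
+//
+//   %lhs = arith.extsi %a : vector<2x8xi8> to vector<2x8xi32>  // explicit
+//   // ... or `%a : vector<2x8xi8>` passed directly to `vector.contract`,
+//   // relying on its implicit sign-extension of narrower integer operands.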
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+}
+
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
 /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
 /// necessary, a single smmla instruction is emitted.
-class LowerContractionToSMMLAPattern
+class LowerContractionToNeonI8MMPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -88,39 +168,64 @@ class LowerContractionToSMMLAPattern
       return failure();
     }

-    // Check two extsi inputs Rhs Lhs for contract.
-    arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
-    arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
-    if (!origLhsExtOp || !origRhsExtOp) {
+    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+    // values before the extension. All four signed/unsigned combinations for
+    // input operands are supported, but they are lowered to different
+    // operations. Determine which is the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
       return failure();
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
     }
+    if (!maybeRhs)
+      return failure();
+
+    Value origLhs = *maybeLhs;
+    Value origRhs = *maybeRhs;

     // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
     // following neon instruction. Check inputs for extsi are <=i8
-    Value extsiLhs;
-    Value extsiRhs;
-    if (auto lhsExtInType =
-            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
+    Value extLhs;
+    Value extRhs;
+    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
       if (lhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetLhsExtTy =
             matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhsExtOp.getIn());
+        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
+          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
+        else
+          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
       }
     }
-    if (auto rhsExtInType =
-            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
+    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
       if (rhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetRhsExtTy =
             matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhsExtOp.getIn());
+        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
+          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
+        else
+          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
       }
     }

-    if (!extsiLhs || !extsiRhs) {
+    if (!extLhs || !extRhs) {
       return failure();
     }

@@ -155,11 +260,11 @@ class LowerContractionToSMMLAPattern
       AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
       AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
       AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
@@ -191,6 +296,13 @@ class LowerContractionToSMMLAPattern
         tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
       }

+      // Transpose ACC if doing signed by unsigned multiplication, because
+      // we're using the instruction for unsigned by signed multiplication with
+      // reversed operands.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledAcc = rewriter.create<vector::TransposeOp>(
+            loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
       auto collapsedInputType =
           VectorType::get(inputExpandedType.getNumElements(), inputElementType);
@@ -211,15 +323,21 @@ class LowerContractionToSMMLAPattern
       }

       // Insert contract op
-      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
-          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-          collapsedRhs);
+      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
+                        collapsedRes, collapsedLhs, collapsedRhs);

       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           kAcc.getLoc(), tiledAcc.getType(), kAcc);

-      // With vecmat, only one row of tiled ACC can be inserted into file result
+      // Because of the reversed operands the result is obtained transposed.
+      // Transpose it back.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledRes = rewriter.create<vector::TransposeOp>(
+            loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+      // With vecmat, only one row of tiled ACC can be inserted into the final
+      // result.
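+      // (For a vecmat the M dimension is 1, so only the first row of the 2x2
+      // accumulator tile carries a meaningful result.)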
       if (isVecmat) {
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
       }
@@ -239,8 +357,8 @@ class LowerContractionToSMMLAPattern

 } // namespace

-void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
+  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
 }