66//
77// ===----------------------------------------------------------------------===//
88//
9- // This file implements lowering patterns from vector.contract to
10- // SVE I8MM operations .
9+ // This file implements lowering patterns from vector.contract to operations
10+ // that map to instructions from the SVE FEAT_I8MM extension .
1111//
12- // ===---
12+ // ===----------------------------------------------------------------------===//
1313
1414#include " mlir/Dialect/Arith/IR/Arith.h"
1515#include " mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -33,10 +33,11 @@ namespace {
3333// Check if the given value is a result of the operation `T` (which must be
3434// sign- or zero- extend) from i8 to i32. Return the value before the extension.
3535template <typename T>
36- inline std::enable_if_t <(std::is_base_of_v<arith::ExtSIOp, T> ||
37- std::is_base_of_v<arith::ExtUIOp, T>),
38- std::optional<Value>>
39- extractExtOperand (Value v, Type i8Ty, Type i32Ty) {
36+ std::optional<Value> extractExtOperand (Value v, Type i8Ty, Type i32Ty) {
37+
38+ static_assert (llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
39+ " Must be instantiated with either sign- or zero- extension op" );
40+
4041 auto extOp = dyn_cast_or_null<T>(v.getDefiningOp ());
4142 if (!extOp)
4243 return {};
@@ -79,6 +80,37 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
7980 }
8081}
8182
83+ // Lower a contraction operation that performs a matrix multiplication
84+ // of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
85+ // for the left-hand side and the right-hand side, respectively,
86+ // yielding a <Mx[N]> 32-bit integer result.
87+ //
88+ // The operands' shapes are such that the operands can be evenly split into
89+ // sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
90+ // The intent is that M and N are chosen (by higher level transforms) in such a
91+ // way as to maximise register usage. The main use case we envision as of now is
92+ // MMT4D, thus the RHS operand is expected pre-transposed.
93+ //
94+ // The matrix multiplication is performed by unrolling the usual tiled matrix
95+ // multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
96+ // <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
97+ //
98+ // One way to illustrate the operation is as follows:
99+ //
100+ // RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
101+ // +-----------------------------
102+ // LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
103+ // <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
104+ // ... | ... ... ... ...
105+ // <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
106+ //
107+ // The RHS operand is unpacked into N/2 values, each representing a sequence of
108+ // VSCALE number of sub-tiles with dimensions <8x2>.
109+ // The LHS operand is initially unpacked into M/2 values, each representing a
110+ // sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
111+ // VSCALE times.
112+ // Multiplying the thus-replicated LHS sub-tile by the corresponding RHS sub-tile
113+ // correctly computes an entire result sub-tile.
82114class LowerContractionToSVEI8MMPattern
83115 : public OpRewritePattern<vector::ContractionOp> {
84116public:
@@ -90,15 +122,11 @@ class LowerContractionToSVEI8MMPattern
90122 mlir::VectorType lhsType = op.getLhsType ();
91123 mlir::VectorType rhsType = op.getRhsType ();
92124
93- // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
94- // eventually expect from MMT4D. M and N dimensions must be even and at
95- // least 2.
96- if (!lhsType.hasRank () || lhsType.getRank () != 2 || !rhsType.hasRank () ||
97- rhsType.getRank () != 2 )
98- return failure ();
99-
100- if (lhsType.isScalable () || !rhsType.isScalable ())
101- return failure ();
125+ // Check the operands have the expected shape. M and N dimensions must be
126+ // even and at least 2.
127+ if (lhsType.getRank () != 2 || rhsType.getRank () != 2 ||
128+ lhsType.isScalable () || !rhsType.isScalable ())
129+ return rewriter.notifyMatchFailure (op, " non-matching operand shape" );
102130
103131 // M, N, and K are the conventional names for matrix dimensions in the
104132 // context of matrix multiplication.
@@ -108,7 +136,7 @@ class LowerContractionToSVEI8MMPattern
108136
109137 if (lhsType.getDimSize (1 ) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
110138 N % 2 != 0 || !rhsType.getScalableDims ()[0 ])
111- return failure ( );
139+ return rewriter. notifyMatchFailure (op, " non-matching operand shape " );
112140
113141 // Check permutation maps. For now only accept
114142 // lhs: (d0, d1, d2) -> (d0, d2)
@@ -124,28 +152,31 @@ class LowerContractionToSVEI8MMPattern
124152 op.getIndexingMapsArray ()[2 ] !=
125153 AffineMap::getMultiDimMapWithTargets (3 , ArrayRef{0u , 1u },
126154 op.getContext ()))
127- return failure ( );
155+ return rewriter. notifyMatchFailure (op, " non-matching permutation maps " );
128156
129157 // Check iterator types for matrix multiplication.
130158 auto itTypes = op.getIteratorTypesArray ();
131159 if (itTypes.size () != 3 || itTypes[0 ] != vector::IteratorType::parallel ||
132160 itTypes[1 ] != vector::IteratorType::parallel ||
133161 itTypes[2 ] != vector::IteratorType::reduction)
134- return failure ();
162+ return rewriter.notifyMatchFailure (
163+ op, " iterator types do not correspond to matrix multiplication" );
135164
136165 // Check the combining kind is addition.
137166 if (op.getKind () != vector::CombiningKind::ADD)
138- return failure ();
167+ return rewriter.notifyMatchFailure (op,
168+ " combining kind is not an addition" );
139169
140170 // Check the output is a vector of i32 elements.
141- auto outTy = dyn_cast<VectorType>(op.getType ());
171+ auto outTy = dyn_cast<VectorType>(op.getResultType ());
142172 if (!outTy || outTy.getElementType () != rewriter.getI32Type ())
143- return failure ();
173+ return rewriter.notifyMatchFailure (op,
174+ " output type is not a vector of i32" );
144175
145176 // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
146177 // before the extension. All four signed/unsigned combinations for input
147178 // operands are supported, but they are lowered to different operations.
148- // Determina which is the appropriate operation to lower to.
179+ // Determine which is the appropriate operation to lower to.
149180 MMLA mmlaOp = MMLA::Signed;
150181 auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
151182 op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
@@ -155,7 +186,8 @@ class LowerContractionToSVEI8MMPattern
155186 op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
156187 }
157188 if (!maybeLhs)
158- return failure ();
189+ return rewriter.notifyMatchFailure (
190+ op, " LHS is not a sign- or zero- extended i8" );
159191
160192 auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
161193 op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
@@ -169,13 +201,16 @@ class LowerContractionToSVEI8MMPattern
169201 op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
170202 }
171203 if (!maybeRhs)
172- return failure ();
204+ return rewriter.notifyMatchFailure (
205+ op, " RHS is not a sign- or zero- extended i8" );
173206
174207 // One-dimensional vector types for arm_sve.*mmla
175- auto nxv16i8 = VectorType::get (16 , rewriter.getI8Type (), {true });
176- auto nxv4i32 = VectorType::get (4 , rewriter.getI32Type (), {true });
208+ auto nxv16i8 = VectorType::get (/* shape=*/ 16 , rewriter.getI8Type (),
209+ /* scalableDims=*/ {true });
210+ auto nxv4i32 = VectorType::get (/* shape=*/ 4 , rewriter.getI32Type (),
211+ /* scalableDims=*/ {true });
177212
178- // Extract LHS sub-tiles.
178- // Extract LHS sub-tiles.
213+ // Extract LHS sub-tiles with logical shape <2x8>.
179214 SmallVector<Value> lhsTile;
180215 for (int64_t i = 0 ; i < M; i += 2 ) {
181216 // Extract two consecutive rows of the LHS tile.
@@ -199,19 +234,25 @@ class LowerContractionToSVEI8MMPattern
199234 // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
200235 auto RHS = rewriter.create <vector::ShapeCastOp>(
201236 maybeRhs->getLoc (),
202- VectorType::get (8 * N, rewriter.getI8Type (), {true }), *maybeRhs);
237+ VectorType::get (/* shape=*/ 8 * N, rewriter.getI8Type (),
238+ /* scalableDims=*/ {true }),
239+ *maybeRhs);
203240
204- // Extract the RHS sub-tiles.
204- // Extract the RHS sub-tiles.
241+ // Extract the RHS sub-tiles with logical shape <8x[2]>.
205242 SmallVector<Value> rhsTile;
206243 for (int64_t j = 0 ; j < N; j += 2 )
207244 rhsTile.push_back (
208245 rewriter.create <vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8 ));
209246
210247 // Handy types for packing/unpacking of the accumulator tile.
211- auto accRowTy = VectorType::get (N, rewriter.getI32Type (), {true });
212- auto accRowX2Ty = VectorType::get (2 * N, rewriter.getI32Type (), {true });
213- auto accRow64Ty = VectorType::get (N / 2 , rewriter.getI64Type (), {true });
214- auto accRowX264Ty = VectorType::get (N, rewriter.getI64Type (), {true });
248+ auto accRowTy = VectorType::get (/* shape=*/ N, rewriter.getI32Type (),
249+ /* scalableDims=*/ {true });
250+ auto accRowX2Ty = VectorType::get (/* shape=*/ 2 * N, rewriter.getI32Type (),
251+ /* scalableDims=*/ {true });
252+ auto accRow64Ty = VectorType::get (/* shape=*/ N / 2 , rewriter.getI64Type (),
253+ /* scalableDims=*/ {true });
254+ auto accRowX264Ty = VectorType::get (/* shape=*/ N, rewriter.getI64Type (),
255+ /* scalableDims=*/ {true });
215256
216257 // Extract and pack the ACC sub-tiles.
217258 SmallVector<Value> accTile;
0 commit comments