@@ -30,23 +30,26 @@ using namespace mlir;
 using namespace mlir::arm_sve;

 namespace {
-// Get the LHS or RHS side operand of a vector contract. Handle two cases
-//  * if the operand is a sign- or zero- extend operation of type `T` from i8
-//    to i32, return the value before the extension, otherwise
-//  * if the operand is of i8 type and the operation is sign-extend, return the
-//    operand itself.
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
 //
-// This way we handle both explicit sign- or zero- extension or implicit
-// sign-extension.
-template <typename T>
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) to check for.
+//
+// Return success only for extensions from `i8` to `i32`.
+template <typename Op>
 std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {

-  static_assert(llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
                 "Must be instantiated with either sign- or zero- extension op");

-  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
   if (!extOp) {
-    if constexpr (std::is_same<T, arith::ExtSIOp>::value) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
       auto vTy = cast<VectorType>(v.getType());
       if (vTy.getElementType() != i8Ty)
         return {};
@@ -55,6 +58,8 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
     return {};
   }

+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check that it is extended from `i8` to `i32`.
   auto inOp = extOp.getIn();
   auto inTy = dyn_cast<VectorType>(inOp.getType());
   if (!inTy || inTy.getElementType() != i8Ty)
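For illustration, here is a rough sketch (not part of the patch; value names and shapes are invented) of the two situations this helper accepts, written as MLIR IR:

  // Explicit extension: getExtOperand<arith::ExtUIOp> (or the arith::ExtSIOp
  // variant for arith.extsi) would return %a, the value before the extension.
  %lhs = arith.extui %a : vector<4x8xi8> to vector<4x8xi32>

  // Implicit sign-extension: an i8 vector used directly as a `vector.contract`
  // operand, with no extend op in between; getExtOperand<arith::ExtSIOp>
  // would return %b itself, provided %b has i8 elements.
  // %b : vector<4x8xi8>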
@@ -93,37 +98,38 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
   }
 }

-// Lower a contraction operation that performs a matrix multiplication
-// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
-// for the left-hand side and the right-hand side, respectively,
-// yielding a <Mx[N]> 32-bit integer result.
-//
-// The operands shapes are such that the operands can be evenly split into
-// sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
-// The intent is that M and N are chosen (by higher level transforms) in such a
-// way as to maximise register usage. The main use case we envision as of now is
-// MMT4D, thus the RHS operand is expected pre-transposed.
-//
-// The matrix multiplication is performed by unrolling the usual tiled matrix
-// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
-// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
-//
-// One way to illustrate the operation is as follows:
-//
-//   RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
-//                   +-----------------------------
-//   LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-//             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-//              ...  |   ...     ...   ...   ...
-//             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
-//
-// The RHS operand is unpacked into N/2 values, each representing a sequence of
-// VSCALE number of sub-tiles with dimensions <8x2>.
-// The LHS operand is initially unpacked into M/2 values, each representing a
-// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
-// VSCALE times.
-// Multiplying thus replicated LHS sub-tile by the corresposponing RHS sub-tile
-// correctly computes an entire result sub-tile.
+/// Lower a contraction operation that performs a matrix multiplication
+/// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
+/// for the left-hand side and the right-hand side, respectively,
+/// yielding a <Mx[N]> 32-bit integer result.
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, thus the RHS operand is expected
+/// pre-transposed.
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
+/// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
+///
+/// One way to illustrate the operation is as follows:
+///
+///   RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
+///                   +-----------------------------
+///   LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///              ...  |   ...     ...   ...   ...
+///             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence of
+/// VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times.
+/// Multiplying the thus replicated LHS sub-tile by the corresponding RHS
+/// sub-tile correctly computes an entire result sub-tile.
 class LowerContractionToSVEI8MMPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -135,27 +141,30 @@ class LowerContractionToSVEI8MMPattern
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();

-    // Check the operands have the expected shape. M and N dimensions must be
-    // even and at least 2.
-    if (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
-        lhsType.isScalable() || !rhsType.isScalable())
+    // Check the ranks of the types so we can safely examine their dimensions.
+    if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");

-    // M, N, and K are the conventional names for matrix dimensions in the
-    // context of matrix multiplication.
     auto M = lhsType.getDimSize(0);
     auto N = rhsType.getDimSize(0);
     auto K = rhsType.getDimSize(1);

-    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
-        N % 2 != 0 || !rhsType.getScalableDims()[0])
+    // Check the operands have the expected shape:
+    //  * for LHS: fixed vector MxK
+    //  * for RHS: scalable vector [N]xK
+    //  * K == 8
+    //  * M and N even and at least 2
+    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
+        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
+        M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0)
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");

     // Check permutation maps. For now only accept
     //   lhs: (d0, d1, d2) -> (d0, d2)
     //   rhs: (d0, d1, d2) -> (d1, d2)
     //   acc: (d0, d1, d2) -> (d0, d1)
-    // Note: RHS is transposed.
+    // This corresponds to matrix multiplication with transposed RHS.
     if (op.getIndexingMapsArray()[0] !=
         AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
                                              op.getContext()) ||
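For illustration only (a sketch, not taken from the patch; the shapes M = 4, K = 8, the scalable N = [4], and the value names are assumptions), a contraction satisfying the shape and indexing-map checks above could look like:

  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
  // %acc : vector<4x[4]xi32>
  %res = vector.contract {
           indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                            affine_map<(d0, d1, d2) -> (d1, d2)>,
                            affine_map<(d0, d1, d2) -> (d0, d1)>],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>}
         %0, %1, %acc : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

Both operands are sign-extended in this sketch, which would correspond to the signed-by-signed MMLA variant; other combinations of arith.extsi/arith.extui are handled analogously.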
@@ -245,7 +254,7 @@ class LowerContractionToSVEI8MMPattern
     }

     // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
-    auto RHS = rewriter.create<vector::ShapeCastOp>(
+    auto rhs = rewriter.create<vector::ShapeCastOp>(
         maybeRhs->getLoc(),
         VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
                         /*scalableDims=*/{true}),
@@ -255,7 +264,7 @@ class LowerContractionToSVEI8MMPattern
     SmallVector<Value> rhsTile;
     for (int64_t j = 0; j < N; j += 2)
       rhsTile.push_back(
-          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, rhs, j * 8));

     // Handy types for packing/unpacking of the accumulator tile.
     auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),