1- // ===- LowerContractionToNeonI8MMPattern .cpp - Contract to I8MM -*- C++ -*-===//
1+ // ===- LowerContractToNeonPatterns .cpp - Contract to I8MM/BF16 - -*- C++ -*-===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
@@ -93,15 +93,20 @@ class VectorContractRewriter {
9393 // multiplications.
9494 enum class MMLA {
9595 Nop,
96- Signed, // smmla
97- Unsigned, // ummla
98- Mixed, // usmmla
99- MixedSwapped // usmmla with LHS and RHS swapped
96+ SignedInt, // smmla
97+ UnsignedInt, // ummla
98+ MixedInt, // usmmla
99+ Bfloat // bfmmla
100100 };
101101
102102 // Lower-level operation to be emitted.
103103 MMLA mmlaOp = MMLA::Nop;
104104
105+ // Indicate if the operands for the ArmNeon dialect operation need to be
106+ // swapped. Currently this is needed in order to emulate an "summla"
107+ // operation.
108+ bool swapOperands = false ;
109+
105110 // The operand tiles. These are not necessarily the operands of
106111 // `vector.contract`, for example they could be operands to `arith.extsi`
107112 // that is in turn fed into `vector.contract`.
@@ -126,21 +131,22 @@ class VectorContractRewriter {
126131 // Create the matrix multiply and accumulate operation according to `mmlaOp`.
127132 Value createMMLA (PatternRewriter &rewriter, Location loc, Value acc,
128133 Value lhs, Value rhs) {
134+
135+ if (swapOperands)
136+ std::swap (lhs, rhs);
129137 switch (mmlaOp) {
130- case MMLA::Signed :
138+ case MMLA::SignedInt :
131139 return rewriter.createOrFold <arm_neon::SmmlaOp>(loc, acc.getType (), acc,
132140 lhs, rhs);
133- case MMLA::Unsigned :
141+ case MMLA::UnsignedInt :
134142 return rewriter.createOrFold <arm_neon::UmmlaOp>(loc, acc.getType (), acc,
135143 lhs, rhs);
136- case MMLA::Mixed :
144+ case MMLA::MixedInt :
137145 return rewriter.createOrFold <arm_neon::UsmmlaOp>(loc, acc.getType (), acc,
138146 lhs, rhs);
139- case MMLA::MixedSwapped:
140- // The accumulator comes transposed and the result will be transposed
141- // later, so all we have to do here is swap the operands.
142- return rewriter.createOrFold <arm_neon::UsmmlaOp>(loc, acc.getType (), acc,
143- rhs, lhs);
147+ case MMLA::Bfloat:
148+ return rewriter.create <arm_neon::BfmmlaOp>(loc, acc.getType (), acc, lhs,
149+ rhs);
144150 case MMLA::Nop:
145151 llvm_unreachable (" Uninitialized operation type" );
146152 }
@@ -273,7 +279,7 @@ class VectorContractRewriter {
273279 // Transpose ACC if doing signed by unsigned multiplication, because we're
274280 // using the instruction for unsigned by signed multiplication with
275281 // reversed operands.
276- if (mmlaOp == MMLA::MixedSwapped )
282+ if (swapOperands )
277283 tiledAcc = rewriter.create <vector::TransposeOp>(
278284 loc, tiledAcc, ArrayRef<int64_t >({1 , 0 }));
279285
@@ -302,7 +308,7 @@ class VectorContractRewriter {
302308
303309 // Because of the reversed operands the result is obtained transposed.
304310 // Transpose it back.
305- if (mmlaOp == MMLA::MixedSwapped )
311+ if (swapOperands )
306312 tiledRes = rewriter.create <vector::TransposeOp>(
307313 loc, tiledRes, ArrayRef<int64_t >({1 , 0 }));
308314
@@ -339,10 +345,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
339345 // values before the extension. All four signed/unsigned combinations for
340346 // input operands are supported, but they are lowered to different
341347 // operations. Determine which is the appropriate operation to lower to.
342- mmlaOp = MMLA::Signed ;
348+ mmlaOp = MMLA::SignedInt ;
343349 auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs ());
344350 if (!maybeLhs) {
345- mmlaOp = MMLA::Unsigned ;
351+ mmlaOp = MMLA::UnsignedInt ;
346352 maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs ());
347353 }
348354 if (!maybeLhs)
@@ -351,11 +357,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
351357
352358 auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs ());
353359 if (maybeRhs) {
354- if (mmlaOp == MMLA::Unsigned )
355- mmlaOp = MMLA::Mixed ;
360+ if (mmlaOp == MMLA::UnsignedInt )
361+ mmlaOp = MMLA::MixedInt ;
356362 } else {
357- if (mmlaOp == MMLA::Signed)
358- mmlaOp = MMLA::MixedSwapped;
363+ if (mmlaOp == MMLA::SignedInt) {
364+ mmlaOp = MMLA::MixedInt;
365+ swapOperands = true ;
366+ }
359367 maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs ());
360368 }
361369
@@ -372,16 +380,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
372380 auto lhsExtInType = cast<VectorType>(lhs.getType ());
373381 if (lhsExtInType.getElementTypeBitWidth () < 8 )
374382 lhs = extendSmallIntVector (loc, lhsExtInType, lhs,
375- /* signExt */ mmlaOp == MMLA::Signed ||
376- mmlaOp == MMLA::Mixed,
383+ /* signExt */
384+ (mmlaOp == MMLA::SignedInt ||
385+ (mmlaOp == MMLA::MixedInt && !swapOperands)),
377386 rewriter);
378387
379388 auto rhsExtInType = cast<VectorType>(rhs.getType ());
380389 if (rhsExtInType.getElementTypeBitWidth () < 8 )
381-
382390 rhs = extendSmallIntVector (loc, rhsExtInType, rhs,
383- /* signExt */ mmlaOp != MMLA::Unsigned &&
384- mmlaOp != MMLA::Mixed,
391+ /* signExt */
392+ (mmlaOp == MMLA::SignedInt ||
393+ (mmlaOp == MMLA::MixedInt && swapOperands)),
385394 rewriter);
386395
387396 // Initialize parameters for unrolling.
@@ -395,6 +404,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
395404 }
396405};
397406
407+ class VectorContractRewriterBFMMLA : public VectorContractRewriter {
408+ public:
409+ LogicalResult matchAndInit (vector::ContractionOp op,
410+ PatternRewriter &rewriter) {
411+
412+ if (failed (VectorContractRewriter::matchAndInit (op, rewriter)))
413+ return failure ();
414+
415+ // Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
416+ // tiling.
417+ if ((dimM != 1 && dimM % 2 != 0 ) || dimN % 2 != 0 || dimK % 4 != 0 )
418+ return rewriter.notifyMatchFailure (op, " Unsupported operand shapes" );
419+
420+ // Check the output is a vector of Float32 elements.
421+ auto outTy = dyn_cast<VectorType>(op.getResultType ());
422+ if (!outTy || outTy.getElementType () != rewriter.getF32Type ())
423+ return rewriter.notifyMatchFailure (op,
424+ " output type is not a vector of f32" );
425+
426+ // Check the inputs are vectors of BFloat16 elements.
427+ if (op.getLhsType ().getElementType () != rewriter.getBF16Type ())
428+ return rewriter.notifyMatchFailure (op,
429+ " input type is not a vector of bf16" );
430+
431+ mmlaOp = MMLA::Bfloat;
432+ swapOperands = false ;
433+ lhs = op.getLhs ();
434+ rhs = op.getRhs ();
435+ acc = op.getAcc ();
436+
437+ // Initialize parameters for unrolling.
438+ iterationBounds = *op.getShapeForUnroll ();
439+ if (iterationBounds.size () == 3 )
440+ subTileShape = SmallVector<int64_t >({dimM == 1 ? 1 : 2 , 2 , 4 });
441+ else
442+ subTileShape = SmallVector<int64_t >({2 , 4 });
443+
444+ return success ();
445+ }
446+ };
447+
398448// / Lowering from a vector::contractOp to an arm neon smmla intrinsic. This will tile
399449// / any vector.contract into multiple smmla instructions with unrolling so long
400450// / as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -416,10 +466,32 @@ class LowerContractionToNeonI8MMPattern
416466 }
417467};
418468
469+ class LowerContractionToNeonBFMMLAPattern
470+ : public OpRewritePattern<vector::ContractionOp> {
471+ public:
472+ using OpRewritePattern::OpRewritePattern;
473+ LogicalResult matchAndRewrite (vector::ContractionOp op,
474+ PatternRewriter &rewriter) const override {
475+
476+ VectorContractRewriterBFMMLA vcr;
477+ if (failed (vcr.matchAndInit (op, rewriter)))
478+ return failure ();
479+ vcr.lower (op, rewriter);
480+
481+ return success ();
482+ }
483+ };
484+
419485} // namespace
420486
421- void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns (
487+ void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns (
422488 RewritePatternSet &patterns) {
423489 MLIRContext *context = patterns.getContext ();
424490 patterns.add <LowerContractionToNeonI8MMPattern>(context, /* benefit=*/ 2 );
425491}
492+
493+ void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns (
494+ RewritePatternSet &patterns) {
495+ MLIRContext *context = patterns.getContext ();
496+ patterns.add <LowerContractionToNeonBFMMLAPattern>(context, /* benefit=*/ 2 );
497+ }
0 commit comments