@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
264264// / iterator_types = ["parallel", "parallel", "reduction"],
265265// / kind = add} %arg0, %arg1, %cst_f0
266266// / : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
267- // / ```
268- struct CombineContractBroadcast
269- : public OpRewritePattern<vector::ContractionOp> {
270- using OpRewritePattern::OpRewritePattern;
271-
272- LogicalResult matchAndRewrite (vector::ContractionOp contractOp,
273- PatternRewriter &rewriter) const override {
274- SmallVector<AffineMap> maps =
275- llvm::to_vector<4 >(contractOp.getIndexingMapsArray ());
276- Value lhs = contractOp.getLhs ();
277- Value rhs = contractOp.getRhs ();
278- size_t index = 0 ;
279- bool changed = false ;
280- for (Value *operand : {&lhs, &rhs}) {
281- AffineMap &map = maps[index++];
282- auto broadcast = operand->getDefiningOp <vector::BroadcastOp>();
283- if (!broadcast)
284- continue ;
285- // contractionOp can only take vector as operands.
286- auto srcType = dyn_cast<VectorType>(broadcast.getSourceType ());
287- if (!srcType ||
288- srcType.getRank () == broadcast.getResultVectorType ().getRank ())
289- continue ;
290- int64_t rankDiff =
291- broadcast.getResultVectorType ().getRank () - srcType.getRank ();
292- bool innerDimBroadcast = false ;
293- SmallVector<AffineExpr> originalDims;
294- for (const auto &dim : llvm::enumerate (srcType.getShape ())) {
295- if (dim.value () != broadcast.getResultVectorType ().getDimSize (
296- rankDiff + dim.index ())) {
297- innerDimBroadcast = true ;
298- break ;
299- }
300- originalDims.push_back (
301- rewriter.getAffineDimExpr (dim.index () + rankDiff));
267+ // / ```
268+ // /
269+ // / For masked vector.contract, the mask requires updating when a dimension is
270+ // / dropped. In such cases, the dropped dimensions must correspond to the mask's
271+ // / leading unit dimensions. Supporting more generic cases (e.g. non-unit dims)
272+ // / is not supported.
273+ FailureOr<Value> combineContractAndBroadcast (vector::ContractionOp contractOp,
274+ MaskingOpInterface maskingOp,
275+ PatternRewriter &rewriter) {
276+ SmallVector<AffineMap> maps =
277+ llvm::to_vector<4 >(contractOp.getIndexingMapsArray ());
278+ Value lhs = contractOp.getLhs ();
279+ Value rhs = contractOp.getRhs ();
280+ size_t index = 0 ;
281+ bool changed = false ;
282+ for (Value *operand : {&lhs, &rhs}) {
283+ AffineMap &map = maps[index++];
284+ auto broadcast = operand->getDefiningOp <vector::BroadcastOp>();
285+ if (!broadcast)
286+ continue ;
287+ // contractionOp can only take vector as operands.
288+ auto srcType = dyn_cast<VectorType>(broadcast.getSourceType ());
289+ if (!srcType ||
290+ srcType.getRank () == broadcast.getResultVectorType ().getRank ())
291+ continue ;
292+ int64_t rankDiff =
293+ broadcast.getResultVectorType ().getRank () - srcType.getRank ();
294+ bool innerDimBroadcast = false ;
295+ SmallVector<AffineExpr> originalDims;
296+ for (const auto &dim : llvm::enumerate (srcType.getShape ())) {
297+ if (dim.value () !=
298+ broadcast.getResultVectorType ().getDimSize (rankDiff + dim.index ())) {
299+ innerDimBroadcast = true ;
300+ break ;
302301 }
303- // Contract doesn't support inner dimension broadcast. Once this is
304- // relaxed we can remove this case.
305- if (innerDimBroadcast)
306- continue ;
302+ originalDims.push_back (rewriter.getAffineDimExpr (dim.index () + rankDiff));
303+ }
304+ // Contract doesn't support inner dimension broadcast. Once this is
305+ // relaxed we can remove this case.
306+ if (innerDimBroadcast)
307+ continue ;
307308
308- // It would be incorrect to fold a broadcast onto a reduction dimension
309- // of non-unit size.
310- bool nonUnitDimReductionBroadcast = false ;
311- for (int64_t i = 0 ; i < rankDiff; ++i) {
312- if (broadcast.getResultVectorType ().getDimSize (i) != 1 &&
313- isReductionIterator (contractOp.getIteratorTypes ()
314- .getValue ()[map.getDimPosition (i)])) {
315- nonUnitDimReductionBroadcast = true ;
316- break ;
317- }
309+ // It would be incorrect to fold a broadcast onto a reduction dimension
310+ // of non-unit size.
311+ bool nonUnitDimReductionBroadcast = false ;
312+ for (int64_t i = 0 ; i < rankDiff; ++i) {
313+ if (broadcast.getResultVectorType ().getDimSize (i) != 1 &&
314+ isReductionIterator (contractOp.getIteratorTypes ()
315+ .getValue ()[map.getDimPosition (i)])) {
316+ nonUnitDimReductionBroadcast = true ;
317+ break ;
318318 }
319- if (nonUnitDimReductionBroadcast)
320- continue ;
321-
322- AffineMap broadcastMap =
323- AffineMap::get (broadcast.getResultVectorType ().getRank (), 0 ,
324- originalDims, contractOp.getContext ());
325- map = broadcastMap.compose (map);
326- *operand = broadcast.getSource ();
327- changed = true ;
328319 }
320+ if (nonUnitDimReductionBroadcast)
321+ continue ;
329322
330- if (!changed)
331- return failure ();
323+ AffineMap broadcastMap =
324+ AffineMap::get (broadcast.getResultVectorType ().getRank (), 0 ,
325+ originalDims, contractOp.getContext ());
326+ map = broadcastMap.compose (map);
327+ *operand = broadcast.getSource ();
328+ changed = true ;
329+ }
332330
333- // Determine which dims are usused, now that the maps have been composed
334- // with the broadcast maps.
335- llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector (maps);
336- // Compress unused dims.
337- for (auto &m : maps)
338- m = compressDims (m, unusedDimsBitVector);
339- // Compute the combined iterators.
340- SmallVector<Attribute> iterators;
341- for (unsigned i = 0 ; i < unusedDimsBitVector.size (); ++i) {
342- if (!unusedDimsBitVector.test (i))
343- iterators.push_back (contractOp.getIteratorTypes ().getValue ()[i]);
344- }
345- // Check that compressing unused dims isn't removing all reduction dimension
346- // pairs. For example, if the vector.contract had only one reduction
347- // iterator and that was a unit-dimension created by a broadcast,
348- // then we should bail here, otherwise we would create a contract without
349- // a reduction dimension pair.
350- bool hasReductionIteratorApplyingOnBothSides = false ;
351- for (unsigned i = 0 ; i < iterators.size (); ++i) {
352- if (!isReductionIterator (iterators[i]))
353- continue ;
354- if (getResultIndex (maps[0 ], i) && getResultIndex (maps[1 ], i)) {
355- hasReductionIteratorApplyingOnBothSides = true ;
331+ if (!changed)
332+ return failure ();
333+
334+ // Determine which dims are usused, now that the maps have been composed
335+ // with the broadcast maps.
336+ llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector (maps);
337+ // Compress unused dims.
338+ for (auto &m : maps)
339+ m = compressDims (m, unusedDimsBitVector);
340+ // Compute the combined iterators.
341+ SmallVector<Attribute> iterators;
342+ for (unsigned i = 0 , e = unusedDimsBitVector.size (); i < e; ++i) {
343+ if (!unusedDimsBitVector.test (i))
344+ iterators.push_back (contractOp.getIteratorTypes ().getValue ()[i]);
345+ }
346+
347+ // Check whether any of the unused dims is non-unit, e.g.:
348+ // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
349+ // This is only required when collapsing a mask. If there is no mask, skip.
350+ VectorType oldMaskType;
351+ bool isAnyUnusedDimNonUnit = false ;
352+ if (maskingOp) {
353+ oldMaskType = cast<VectorType>(maskingOp.getMask ().getType ());
354+ for (unsigned i = 0 , e = unusedDimsBitVector.size (); i < e; ++i) {
355+ if (unusedDimsBitVector.test (i) && oldMaskType.getShape ()[i] != 1 ) {
356+ isAnyUnusedDimNonUnit = true ;
356357 break ;
357358 }
358359 }
359- if (!hasReductionIteratorApplyingOnBothSides)
360- return failure ();
360+ }
361361
362- // If the compressed maps have a dimension that is not used by either LHS or
363- // RHS then the ContractionOp verifier would fail.
364- if (getUnusedDimsBitVector ({maps[0 ], maps[1 ]}).any ())
365- return failure ();
366- rewriter.replaceOpWithNewOp <vector::ContractionOp>(
367- contractOp, lhs, rhs, contractOp.getAcc (),
368- rewriter.getAffineMapArrayAttr (maps), rewriter.getArrayAttr (iterators));
369- return success ();
362+ // Check that compressing unused dims isn't removing all reduction dimension
363+ // pairs. For example, if the vector.contract had only one reduction
364+ // iterator and that was a unit-dimension created by a broadcast,
365+ // then we should bail here, otherwise we would create a contract without
366+ // a reduction dimension pair.
367+ bool hasReductionIteratorApplyingOnBothSides = false ;
368+ for (unsigned i = 0 ; i < iterators.size (); ++i) {
369+ if (!isReductionIterator (iterators[i]))
370+ continue ;
371+ if (getResultIndex (maps[0 ], i) && getResultIndex (maps[1 ], i)) {
372+ hasReductionIteratorApplyingOnBothSides = true ;
373+ break ;
374+ }
375+ }
376+ if (!hasReductionIteratorApplyingOnBothSides)
377+ return failure ();
378+
379+ // If the compressed maps have a dimension that is not used by either LHS or
380+ // RHS then the ContractionOp verifier would fail.
381+ if (getUnusedDimsBitVector ({maps[0 ], maps[1 ]}).any ())
382+ return failure ();
383+
384+ Operation *newOp = rewriter.create <vector::ContractionOp>(
385+ contractOp.getLoc (), lhs, rhs, contractOp.getAcc (),
386+ rewriter.getAffineMapArrayAttr (maps), rewriter.getArrayAttr (iterators));
387+
388+ // Handle the mask.
389+ if (maskingOp) {
390+ if (isAnyUnusedDimNonUnit)
391+ return rewriter.notifyMatchFailure (contractOp,
392+ " Cannont drop non-unit mask dim." );
393+ assert (unusedDimsBitVector.size () ==
394+ static_cast <size_t >(oldMaskType.getRank ()) &&
395+ " The mask rank is incorrect!" );
396+
397+ // If a dimension has been dropped, update the mask accordingly. Otherwise,
398+ // keep it as is.
399+ Value mask = maskingOp.getMask ();
400+ if (unusedDimsBitVector.count () != 0 ) {
401+ // At this point, two assumptions are made:
402+ // * The unused dimensions are the leading mask dimensions
403+ // (vector.contract does not support inner dim broadcasting).
404+ // * The unused dimensions are all unit.
405+ // These conditions are effectively verified in the blocks preceeding this
406+ // one.
407+ auto newShape =
408+ oldMaskType.getShape ().drop_front (unusedDimsBitVector.count ());
409+ auto newShapeScalableDims =
410+ oldMaskType.getScalableDims ().drop_front (unusedDimsBitVector.count ());
411+ VectorType maskOpType =
412+ VectorType::get (newShape, rewriter.getI1Type (), newShapeScalableDims);
413+ mask = rewriter
414+ .create <vector::ShapeCastOp>(contractOp.getLoc (), maskOpType,
415+ maskingOp.getMask ())
416+ .getResult ();
417+ }
418+
419+ newOp = mlir::vector::maskOperation (rewriter, newOp, mask);
420+ }
421+ return newOp->getResult (0 );
422+ }
423+
424+ struct CombineContractBroadcastMask
425+ : public MaskableOpRewritePattern<vector::ContractionOp> {
426+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
427+ FailureOr<Value>
428+
429+ matchAndRewriteMaskableOp (vector::ContractionOp contractOp,
430+ MaskingOpInterface maskingOp,
431+ PatternRewriter &rewriter) const override {
432+ return combineContractAndBroadcast (contractOp, maskingOp, rewriter);
370433 }
371434};
372435
@@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
22372300
22382301void mlir::vector::populateVectorReductionToContractPatterns (
22392302 RewritePatternSet &patterns, PatternBenefit benefit) {
2240- patterns.add <MultiReduceToContract, CombineContractBroadcast ,
2303+ patterns.add <MultiReduceToContract, CombineContractBroadcastMask ,
22412304 CombineContractABTranspose, CombineContractResultTranspose>(
22422305 patterns.getContext (), benefit);
22432306}
0 commit comments