@@ -486,6 +486,73 @@ struct FoldAttentionWithProducerReshapeByExpansion final
486486 linalg::ControlFusionFn controlFoldingReshapes;
487487};
488488
489+ // / Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
490+ // / The dims in `update` between the batch dims and the continuous slice
491+ // / represent the indexed dimensions. Remove the leading unit dims from the
492+ // / indexed dims.
493+ struct FoldScatterNonIterationUnitDims final
494+ : public OpRewritePattern<ScatterOp> {
495+ FoldScatterNonIterationUnitDims (MLIRContext *context,
496+ linalg::ControlDropUnitDims options,
497+ PatternBenefit benefit = 1 )
498+ : OpRewritePattern<ScatterOp>(context, benefit),
499+ options (std::move(options)) {}
500+
501+ LogicalResult matchAndRewrite (ScatterOp scatterOp,
502+ PatternRewriter &rewriter) const override {
503+ if (options.rankReductionStrategy !=
504+ linalg::ControlDropUnitDims::RankReductionStrategy::
505+ ReassociativeReshape) {
506+ return rewriter.notifyMatchFailure (
507+ scatterOp, " Only reassociative reshape strategy supported" );
508+ }
509+ llvm::SmallVector<unsigned > canDrop = options.controlFn (scatterOp);
510+ const ArrayRef<int64_t > updateShape = scatterOp.getUpdateType ().getShape ();
511+
512+ // Find the number of leading unit dimensions
513+ int64_t rankOfContiguousSlice =
514+ scatterOp.getOriginalType ().getRank () - scatterOp.getIndexDepth ();
515+ ArrayRef<int64_t > indexedDims =
516+ scatterOp.getUpdateSliceShape ().drop_back (rankOfContiguousSlice);
517+ int64_t numDimsToDrop =
518+ llvm::find_if (indexedDims, [](int64_t val) { return val != 1 ; }) -
519+ scatterOp.getUpdateSliceShape ().begin () - 1 ;
520+
521+ int64_t batchRank = scatterOp.getBatchRank ();
522+ llvm::erase_if (canDrop, [&](unsigned dimPos) {
523+ return dimPos < batchRank || dimPos > batchRank + numDimsToDrop;
524+ });
525+ if (canDrop.empty ()) {
526+ return failure ();
527+ }
528+
529+ SmallVector<int64_t > droppedUpdateShape;
530+ droppedUpdateShape.reserve (updateShape.size () - canDrop.size ());
531+ for (auto [idx, dimLen] : llvm::enumerate (updateShape)) {
532+ if (!llvm::is_contained (canDrop, idx)) {
533+ droppedUpdateShape.push_back (dimLen);
534+ }
535+ }
536+
537+ auto reassoc =
538+ getReassociationIndicesForCollapse (updateShape, droppedUpdateShape);
539+ assert (reassoc.has_value () && " expected reassociation to be valid" );
540+ auto collapseOp = rewriter.create <tensor::CollapseShapeOp>(
541+ scatterOp.getLoc (),
542+ RankedTensorType::get (droppedUpdateShape,
543+ scatterOp.getUpdateType ().getElementType ()),
544+ scatterOp.getUpdates (), reassoc.value ());
545+
546+ rewriter.modifyOpInPlace (scatterOp, [&]() {
547+ scatterOp.setOperand (ScatterOp::kUpdatesOpNum , collapseOp.getResult ());
548+ });
549+ return success ();
550+ }
551+
552+ private:
553+ linalg::ControlDropUnitDims options;
554+ };
555+
489556} // namespace
490557
491558/// Return the `reassociation` indices to use to collapse the operand when the
@@ -708,4 +775,14 @@ void populateFoldReshapeOpsByExpansionPatterns(
708775 patterns.getContext (), controlFoldingReshapes);
709776}
710777
778+ SmallVector<unsigned > defaultControlDropUnitDims (Operation *op) {
779+ auto fusionOp = cast<LinalgFusionOpInterface>(op);
780+ return llvm::to_vector (llvm::seq<unsigned >(0 , fusionOp.getNumLoops ()));
781+ }
782+
783+ void populateFoldUnitExtentDimsPatterns (
784+ RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
785+ patterns.add <FoldScatterNonIterationUnitDims>(patterns.getContext (), options);
786+ }
787+
711788} // namespace mlir::iree_compiler::IREE::LinalgExt
0 commit comments