2525#include " mlir/IR/PatternMatch.h"
2626#include " mlir/Support/LLVM.h"
2727#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
28+ #include " mlir/Transforms/RegionUtils.h"
2829#include < optional>
2930#include < utility>
3031
@@ -590,44 +591,45 @@ class ExpansionInfo {
590591 // the expanded op.
591592 LogicalResult compute (LinalgOp linalgOp, OpOperand *fusableOpOperand,
592593 ArrayRef<AffineMap> reassociationMaps,
593- ArrayRef<int64_t > expandedShape,
594- ArrayRef<int64_t > collapsedShape,
594+ ArrayRef<OpFoldResult> expandedShape,
595595 PatternRewriter &rewriter);
596596 unsigned getOrigOpNumDims () const { return reassociation.size (); }
597597 unsigned getExpandedOpNumDims () const { return expandedOpNumDims; }
598598 ReassociationIndicesRef getExpandedDims (unsigned i) const {
599599 return reassociation[i];
600600 }
601- ArrayRef<int64_t > getExpandedShapeOfDim (unsigned i) const {
601+ ArrayRef<OpFoldResult > getExpandedShapeOfDim (unsigned i) const {
602602 return expandedShapeMap[i];
603603 }
604- ArrayRef<int64_t > getOriginalShape () const { return originalLoopExtent; }
604+ ArrayRef<OpFoldResult > getOriginalShape () const { return originalLoopExtent; }
605605
606606private:
607607 // / Reassociation from the dimensions in the original operation to the
608608 // / dimension of the expanded operation.
609609 SmallVector<ReassociationIndices> reassociation;
610610 // / Mapping from extent of loops in the original operation, to the extent of
611611 // / loops in the expanded operation.
612- SmallVector<SmallVector<int64_t >> expandedShapeMap;
612+ SmallVector<SmallVector<OpFoldResult >> expandedShapeMap;
613613 // / Extent of the loop in the original operation.
614- SmallVector<int64_t > originalLoopExtent;
614+ SmallVector<OpFoldResult > originalLoopExtent;
615615 unsigned expandedOpNumDims;
616616};
617617} // namespace
618618
619619LogicalResult ExpansionInfo::compute (LinalgOp linalgOp,
620620 OpOperand *fusableOpOperand,
621621 ArrayRef<AffineMap> reassociationMaps,
622- ArrayRef<int64_t > expandedShape,
623- ArrayRef<int64_t > collapsedShape,
622+ ArrayRef<OpFoldResult> expandedShape,
624623 PatternRewriter &rewriter) {
625624 if (reassociationMaps.empty ())
626625 return failure ();
627626 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap (fusableOpOperand);
628627
629- SmallVector<int64_t , 4 > originalLoopRange = linalgOp.getStaticLoopRanges ();
630- originalLoopExtent.assign (originalLoopRange.begin (), originalLoopRange.end ());
628+ OpBuilder::InsertionGuard g (rewriter);
629+ rewriter.setInsertionPoint (linalgOp);
630+ originalLoopExtent = llvm::map_to_vector (
631+ linalgOp.createLoopRanges (rewriter, linalgOp->getLoc ()),
632+ [](Range r) { return r.size ; });
631633
632634 reassociation.clear ();
633635 expandedShapeMap.clear ();
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
639641 unsigned pos = cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
640642 AffineMap foldedDims = reassociationMaps[resultExpr.index ()];
641643 numExpandedDims[pos] = foldedDims.getNumResults ();
642- ArrayRef<int64_t > shape =
644+ ArrayRef<OpFoldResult > shape =
643645 expandedShape.slice (foldedDims.getDimPosition (0 ), numExpandedDims[pos]);
644646 expandedShapeMap[pos].assign (shape.begin (), shape.end ());
645647 }
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
660662 return success ();
661663}
662664
663- // / Expanding the body of a linalg operation requires adaptations of the
664- // / accessed loop indices. Specifically, access of indices in the original
665- // / operation need to be replaced with linearizations of indices in the expanded
666- // / op. That requires the shape of the expanded dimensions to be static (at
667- // / least all but the most significant). For now check that these are all
668- // / statically sized. Note that this could be extended to handle dynamic case,
669- // / but the implementation below uses `affine.apply` which seems to have issues
670- // / when the shapes are not static.
671- static LogicalResult isLinalgOpExpandable (LinalgOp linalgOp,
672- const ExpansionInfo &expansionInfo,
673- PatternRewriter &rewriter) {
674- if (!linalgOp.hasIndexSemantics ())
675- return success ();
676- for (unsigned i : llvm::seq<unsigned >(0 , expansionInfo.getOrigOpNumDims ())) {
677- ArrayRef<int64_t > expandedShape = expansionInfo.getExpandedShapeOfDim (i);
678- if (expandedShape.size () == 1 )
679- continue ;
680- for (int64_t shape : expandedShape.drop_front ()) {
681- if (ShapedType::isDynamic (shape)) {
682- return rewriter.notifyMatchFailure (
683- linalgOp, " cannot expand due to index semantics and dynamic dims" );
684- }
685- }
686- }
687- return success ();
688- }
689-
690665// / Return the indexing map to use in the expanded op for a given the
691666// / `indexingMap` of the original operation.
692667static AffineMap
@@ -706,18 +681,23 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
706681 builder.getContext ());
707682}
708683
709- // / Return the type of the operand/result to use in the expanded op given the
710- // / type in the original op.
711- static RankedTensorType getExpandedType ( RankedTensorType originalType,
712- AffineMap indexingMap,
713- const ExpansionInfo &expansionInfo) {
714- SmallVector<int64_t > expandedShape;
684+ // / Return the shape and type of the operand/result to use in the expanded op
685+ // / given the type in the original op.
686+ static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
687+ getExpandedShapeAndType (RankedTensorType originalType, AffineMap indexingMap,
688+ const ExpansionInfo &expansionInfo) {
689+ SmallVector<OpFoldResult > expandedShape;
715690 for (AffineExpr expr : indexingMap.getResults ()) {
716691 unsigned dim = cast<AffineDimExpr>(expr).getPosition ();
717- auto dimExpansion = expansionInfo.getExpandedShapeOfDim (dim);
692+ ArrayRef<OpFoldResult> dimExpansion =
693+ expansionInfo.getExpandedShapeOfDim (dim);
718694 expandedShape.append (dimExpansion.begin (), dimExpansion.end ());
719695 }
720- return RankedTensorType::get (expandedShape, originalType.getElementType ());
696+ SmallVector<int64_t > expandedStaticShape;
697+ std::tie (expandedStaticShape, std::ignore) =
698+ decomposeMixedValues (expandedShape);
699+ return {expandedShape, RankedTensorType::get (expandedStaticShape,
700+ originalType.getElementType ())};
721701}
722702
723703// / Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -765,49 +745,28 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
765745 // Linearize the expanded indices of the original index dimension.
766746 OpBuilder::InsertionGuard guard (rewriter);
767747 rewriter.setInsertionPointAfter (indexOp);
768- ArrayRef<int64_t > expandedDimsShape =
748+ ArrayRef<OpFoldResult > expandedDimsShape =
769749 expansionInfo.getExpandedShapeOfDim (indexOp.getDim ()).drop_front ();
770750 SmallVector<Value> expandedIndices;
771751 expandedIndices.reserve (expandedDims.size () - 1 );
772752 llvm::transform (
773753 expandedDims.drop_front (), std::back_inserter (expandedIndices),
774754 [&](int64_t dim) { return rewriter.create <IndexOp>(loc, dim); });
775- Value newIndex = rewriter.create <IndexOp>(loc, expandedDims.front ());
776- for (auto it : llvm::zip (expandedDimsShape, expandedIndices)) {
777- assert (!ShapedType::isDynamic (std::get<0 >(it)));
778- AffineExpr idx, acc;
755+ OpFoldResult newIndex =
756+ rewriter.create <IndexOp>(loc, expandedDims.front ()).getResult ();
757+ for (auto [expandedShape, expandedIndex] :
758+ llvm::zip (expandedDimsShape, expandedIndices)) {
759+ AffineExpr idx, acc, shape;
779760 bindDims (rewriter.getContext (), idx, acc);
780- newIndex = rewriter.create <affine::AffineApplyOp>(
781- indexOp.getLoc (), idx + acc * std::get<0 >(it),
782- ValueRange{std::get<1 >(it), newIndex});
783- }
784- rewriter.replaceOp (indexOp, newIndex);
785- }
786- }
787-
788- // / Checks if a single dynamic dimension expanded into multiple dynamic
789- // / dimensions.
790- static LogicalResult
791- validateDynamicDimExpansion (LinalgOp linalgOp,
792- const ExpansionInfo &expansionInfo,
793- PatternRewriter &rewriter) {
794- for (unsigned i : llvm::seq<unsigned >(0 , expansionInfo.getOrigOpNumDims ())) {
795- ArrayRef<int64_t > expandedShape = expansionInfo.getExpandedShapeOfDim (i);
796- if (expandedShape.size () == 1 )
797- continue ;
798- bool foundDynamic = false ;
799- for (int64_t shape : expandedShape) {
800- if (!ShapedType::isDynamic (shape))
801- continue ;
802- if (foundDynamic) {
803- return rewriter.notifyMatchFailure (
804- linalgOp, " cannot infer expanded shape with multiple dynamic "
805- " dims in the same reassociation group" );
806- }
807- foundDynamic = true ;
761+ bindSymbols (rewriter.getContext (), shape);
762+ newIndex = affine::makeComposedFoldedAffineApply (
763+ rewriter, indexOp.getLoc (), idx + acc * shape,
764+ ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
808765 }
766+ Value newIndexVal =
767+ getValueOrCreateConstantIndexOp (rewriter, indexOp.getLoc (), newIndex);
768+ rewriter.replaceOp (indexOp, newIndexVal);
809769 }
810- return success ();
811770}
812771
813772// Create an expanded transpose op.
@@ -910,31 +869,34 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
910869 " preconditions for fuse operation failed" );
911870
912871 Location loc = linalgOp.getLoc ();
913- // Check if reshape is expanding or collapsing.
914- auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
915- auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
916- bool isExpanding = (expandingReshapeOp != nullptr );
917- RankedTensorType expandedType = isExpanding
918- ? expandingReshapeOp.getResultType ()
919- : collapsingReshapeOp.getSrcType ();
920- RankedTensorType collapsedType = isExpanding
921- ? expandingReshapeOp.getSrcType ()
922- : collapsingReshapeOp.getResultType ();
872+ SmallVector<OpFoldResult> expandedShape, collapsedShape;
873+ SmallVector<AffineMap, 4 > reassociationIndices;
874+ Value src;
875+ if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
876+ // Try to move the dynamic dimensions in output shape before the `linalgOp`
877+ // to maintain SSA validity
878+ if (failed (moveValueDefinitions (
879+ rewriter, expandingReshapeOp.getOutputShape (), linalgOp)))
880+ return std::nullopt ;
881+
882+ expandedShape = expandingReshapeOp.getMixedOutputShape ();
883+ reassociationIndices = expandingReshapeOp.getReassociationMaps ();
884+ src = expandingReshapeOp.getSrc ();
885+ } else {
886+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
887+ if (!collapsingReshapeOp)
888+ return std::nullopt ;
889+
890+ expandedShape = tensor::getMixedSizes (
891+ rewriter, collapsingReshapeOp->getLoc (), collapsingReshapeOp.getSrc ());
892+ reassociationIndices = collapsingReshapeOp.getReassociationMaps ();
893+ src = collapsingReshapeOp.getSrc ();
894+ }
923895
924896 ExpansionInfo expansionInfo;
925- if (failed (expansionInfo.compute (
926- linalgOp, fusableOpOperand,
927- isExpanding ? expandingReshapeOp.getReassociationMaps ()
928- : collapsingReshapeOp.getReassociationMaps (),
929- expandedType.getShape (), collapsedType.getShape (), rewriter)))
930- return std::nullopt ;
931-
932- // TODO: With the support of multiple dynamic dims expansion in
933- // tensor.expand_shape op, this case can be handled.
934- if (failed (validateDynamicDimExpansion (linalgOp, expansionInfo, rewriter)))
935- return std::nullopt ;
936-
937- if (failed (isLinalgOpExpandable (linalgOp, expansionInfo, rewriter)))
897+ if (failed (expansionInfo.compute (linalgOp, fusableOpOperand,
898+ reassociationIndices, expandedShape,
899+ rewriter)))
938900 return std::nullopt ;
939901
940902 SmallVector<AffineMap, 4 > expandedOpIndexingMaps = llvm::to_vector<4 >(
@@ -950,15 +912,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
950912 expandedOpOperands.reserve (linalgOp.getNumDpsInputs ());
951913 for (OpOperand *opOperand : linalgOp.getDpsInputOperands ()) {
952914 if (opOperand == fusableOpOperand) {
953- expandedOpOperands.push_back (isExpanding ? expandingReshapeOp.getSrc ()
954- : collapsingReshapeOp.getSrc ());
915+ expandedOpOperands.push_back (src);
955916 continue ;
956917 }
957918 if (auto opOperandType =
958919 dyn_cast<RankedTensorType>(opOperand->get ().getType ())) {
959920 AffineMap indexingMap = linalgOp.getMatchingIndexingMap (opOperand);
960- RankedTensorType expandedOperandType =
961- getExpandedType (opOperandType, indexingMap, expansionInfo);
921+ SmallVector<OpFoldResult> expandedOperandShape;
922+ RankedTensorType expandedOperandType;
923+ std::tie (expandedOperandShape, expandedOperandType) =
924+ getExpandedShapeAndType (opOperandType, indexingMap, expansionInfo);
962925 if (expandedOperandType != opOperand->get ().getType ()) {
963926 // Reshape the operand to get the right type.
964927 SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +935,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
972935 /* isExpandingReshape=*/ true )))
973936 return std::nullopt ;
974937 expandedOpOperands.push_back (rewriter.create <tensor::ExpandShapeOp>(
975- loc, expandedOperandType, opOperand->get (), reassociation));
938+ loc, expandedOperandType, opOperand->get (), reassociation,
939+ expandedOperandShape));
976940 continue ;
977941 }
978942 }
@@ -983,8 +947,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
983947 for (OpOperand &opOperand : linalgOp.getDpsInitsMutable ()) {
984948 AffineMap indexingMap = linalgOp.getMatchingIndexingMap (&opOperand);
985949 auto opOperandType = cast<RankedTensorType>(opOperand.get ().getType ());
986- RankedTensorType expandedOutputType =
987- getExpandedType (opOperandType, indexingMap, expansionInfo);
950+ SmallVector<OpFoldResult> expandedOutputShape;
951+ RankedTensorType expandedOutputType;
952+ std::tie (expandedOutputShape, expandedOutputType) =
953+ getExpandedShapeAndType (opOperandType, indexingMap, expansionInfo);
988954 if (expandedOutputType != opOperand.get ().getType ()) {
989955 SmallVector<ReassociationIndices> reassociation =
990956 getReassociationForExpansion (indexingMap, expansionInfo);
@@ -997,7 +963,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
997963 /* isExpandingReshape=*/ true )))
998964 return std::nullopt ;
999965 outputs.push_back (rewriter.create <tensor::ExpandShapeOp>(
1000- loc, expandedOutputType, opOperand.get (), reassociation));
966+ loc, expandedOutputType, opOperand.get (), reassociation,
967+ expandedOutputShape));
1001968 } else {
1002969 outputs.push_back (opOperand.get ());
1003970 }
0 commit comments