 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <utility>
 
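
Note: the new RegionUtils.h include provides moveValueDefinitions(), used
below in fuseWithReshapeByExpansion(). Paraphrased declaration (additional
overloads taking an explicit DominanceInfo exist; treat the exact signature
as approximate):

    // Hoist the defining ops of `values` (and whatever they transitively
    // depend on) above `insertionPoint`; fails when that is not legal.
    LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
                                       Operation *insertionPoint);
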
@@ -590,44 +591,45 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
 
 private:
   /// Reassociation from the dimensions in the original operation to the
   /// dimension of the expanded operation.
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
 
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
 
-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });
 
   reassociation.clear();
   expandedShapeMap.clear();
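
Note: each loop extent is now an OpFoldResult, i.e. either an IntegerAttr (a
static extent) or a Value (a dynamic extent such as a tensor.dim result). A
minimal sketch, not part of the patch, of how such an extent can be inspected
with the standard getConstantIntValue() helper:

    // Classify one extent produced by createLoopRanges().
    void inspectExtent(OpFoldResult extent) {
      if (std::optional<int64_t> cst = getConstantIntValue(extent)) {
        // Static extent: *cst holds the constant size.
        return;
      }
      // Dynamic extent: the SSA value that computes it.
      Value dynamicExtent = cast<Value>(extent);
      (void)dynamicExtent;
    }
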
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }
 
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
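
Note: the guard deleted here bailed out on any op with index semantics whose
non-leading expanded extents were dynamic. It is no longer needed because the
index linearization (see updateExpandedGenericOpRegion below) now binds each
extent as an affine symbol instead of folding it in as a constant. A sketch
of the difference, where `ctx` stands for an assumed MLIRContext*:

    AffineExpr idx, acc, shape;
    bindDims(ctx, idx, acc);
    // Before: the extent had to be a compile-time constant `cst`:
    //   AffineExpr expr = idx + acc * cst;
    // After: the extent is a symbol and may be backed by an SSA value:
    bindSymbols(ctx, shape);
    AffineExpr expr = idx + acc * shape;
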
 /// Return the indexing map to use in the expanded op for a given the
 /// `indexingMap` of the original operation.
 static AffineMap
@@ -708,16 +683,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
 
 /// Return the type of the operand/result to use in the expanded op given the
 /// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t> expandedStaticShape;
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
+    llvm::append_range(expandedStaticShape,
+                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+                         std::optional<int64_t> staticShape =
+                             getConstantIntValue(ofr);
+                         if (staticShape) {
+                           return staticShape.value();
+                         }
+                         return ShapedType::kDynamic;
+                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }
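
Note: the helper now returns two coupled results: the mixed OpFoldResult
shape, which is later passed as the explicit output_shape of a
tensor.expand_shape, and the tensor type, in which every non-constant entry
becomes ShapedType::kDynamic. A hypothetical caller, written with structured
bindings (the patch itself uses std::tie at the call sites):

    auto [mixedShape, expandedType] =
        getExpandedShapeAndType(operandType, indexingMap, expansionInfo);
    // mixedShape:   one OpFoldResult per expanded dimension.
    // expandedType: RankedTensorType with kDynamic placeholders.
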
 
 /// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -765,49 +752,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }
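
Note: the linearization runs from the outermost expanded index inward with
the recurrence newIndex = idx + newIndex * extent. Worked example: a
dimension expanded into (d0, d1, d2) with extents (s0, s1, s2) is replaced
step by step as

    // newIndex = d0
    // newIndex = d1 + newIndex * s1   // = d0 * s1 + d1
    // newIndex = d2 + newIndex * s2   // = (d0 * s1 + d1) * s2 + d2

Because each extent is bound as an affine symbol, s1 and s2 may be SSA
values; makeComposedFoldedAffineApply() still folds the whole chain to a
constant when all inputs are static, which is also why the old assert on
dynamic shapes could be dropped.
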
 
 // Create an expanded transpose op.
@@ -910,31 +875,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");
 
   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    // Try to move the dynamic dimensions in output shape before the `linalgOp`
+    // to maintain SSA validity
+    if (failed(moveValueDefinitions(
+            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+      return std::nullopt;
+
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }
 
   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+                                   reassociationIndices, expandedShape,
+                                   rewriter)))
     return std::nullopt;
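
Note: the moveValueDefinitions() call above is about SSA dominance, as the
in-diff comment says. The fused op is created at the position of the linalg
op, but the expand_shape's output_shape operands may be defined after it. A
hypothetical input illustrating the problem (IR in comments, names
illustrative):

    //   %gen = linalg.generic ... -> tensor<?x12xf32>
    //   %d0  = "some.size"() : () -> index
    //   %exp = tensor.expand_shape %gen [[0, 1], [2]]
    //            output_shape [%d0, 4, 12]
    //            : tensor<?x12xf32> into tensor<?x4x12xf32>
    // The fused op replaces %gen and needs %d0 as an output_shape operand,
    // but %d0 is defined after %gen. moveValueDefinitions() hoists %d0's
    // definition above %gen, or fails (and the pattern aborts) when it
    // cannot.
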
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -950,15 +915,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +938,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
@@ -983,8 +950,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +966,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
     } else {
       outputs.push_back(opOperand.get());
     }
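
Note: both reshape sites (inputs and inits) now pass the expanded sizes
explicitly to the tensor.expand_shape builder. That is what makes the
removed validateDynamicDimExpansion() check unnecessary: with output_shape
supplied, several dynamic sizes may land in one reassociation group and
nothing has to be inferred from the result type. Sketch of the builder call
used above (d0, d1 stand for index Values computed earlier; names
illustrative):

    SmallVector<OpFoldResult> outputShape = {d0, rewriter.getIndexAttr(4), d1};
    auto expanded = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandedType, source, reassociation, outputShape);
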