@@ -814,6 +814,55 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
814814 }
815815 return success ();
816816}
817+ // Create an expanded fused op that retains the name for certain ops
818+ // such as fill, copy and transpose and produce a generic op for
819+ // rest of linalg ops.
820+ Operation *createFusedOpForReshapeByExpansion (
821+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
822+ ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
823+ ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
824+ SmallVector<ReassociationIndices> reassociation) {
825+
826+ return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation ())
827+ .Case <TransposeOp>([&](TransposeOp op) {
828+ applyPermutationToVector (reassociation, op.getPermutation ());
829+ SmallVector<int64_t > newPerm;
830+ for (auto reassoc : reassociation) {
831+ for (auto dim : reassoc) {
832+ newPerm.push_back (dim);
833+ }
834+ }
835+ return rewriter.create <TransposeOp>(
836+ linalgOp.getLoc (), expandedOpOperands[0 ], outputs[0 ], newPerm);
837+ })
838+ .Case <FillOp, CopyOp>([&](Operation *op) {
839+ return clone (rewriter, linalgOp, resultTypes,
840+ llvm::to_vector (llvm::concat<Value>(
841+ llvm::to_vector (expandedOpOperands),
842+ llvm::to_vector (outputs))));
843+ })
844+ .Default ([&](Operation *op) {
845+ // The iterator types of the expanded op are all parallel.
846+ SmallVector<utils::IteratorType> iteratorTypes (
847+ expansionInfo.getExpandedOpNumDims (),
848+ utils::IteratorType::parallel);
849+ for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
850+ for (auto j : expansionInfo.getExpandedDims (i))
851+ iteratorTypes[j] = type;
852+ Operation *fused = rewriter.create <GenericOp>(
853+ linalgOp.getLoc (), resultTypes, expandedOpOperands, outputs,
854+ expandedOpIndexingMaps, iteratorTypes);
855+ Region &fusedRegion = fused->getRegion (0 );
856+ Region &originalRegion = linalgOp->getRegion (0 );
857+ rewriter.cloneRegionBefore (originalRegion, fusedRegion,
858+ fusedRegion.begin ());
859+
860+ // Update the index accesses after the expansion.
861+ updateExpandedGenericOpRegion (rewriter, linalgOp.getLoc (), fusedRegion,
862+ expansionInfo);
863+ return fused;
864+ });
865+ }
817866
818867// / Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
819868// / and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
@@ -919,51 +968,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
919968 }
920969 }
921970
922- // The iterator types of the expanded op are all parallel.
923- SmallVector<utils::IteratorType> iteratorTypes (
924- expansionInfo.getExpandedOpNumDims (), utils::IteratorType::parallel);
925- for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
926- for (auto j : expansionInfo.getExpandedDims (i))
927- iteratorTypes[j] = type;
928-
929971 TypeRange resultTypes = ValueRange (outputs).getTypes ();
930- Operation *fusedOp;
931-
932- TypeSwitch<Operation *>(linalgOp.getOperation ())
933- .Case <GenericOp>([&](GenericOp op) {
934- fusedOp = rewriter.create <GenericOp>(
935- linalgOp.getLoc (), resultTypes, expandedOpOperands, outputs,
936- expandedOpIndexingMaps, iteratorTypes);
937- Region &fusedRegion = fusedOp->getRegion (0 );
938- Region &originalRegion = linalgOp->getRegion (0 );
939- rewriter.cloneRegionBefore (originalRegion, fusedRegion,
940- fusedRegion.begin ());
941-
942- // Update the index accesses after the expansion.
943- updateExpandedGenericOpRegion (rewriter, loc, fusedRegion,
944- expansionInfo);
945- })
946- .Case <TransposeOp>([&](TransposeOp op) {
947- SmallVector<ReassociationIndices> reassociation =
948- isExpanding ? expandingReshapeOp.getReassociationIndices ()
949- : collapsingReshapeOp.getReassociationIndices ();
950- applyPermutationToVector (reassociation, op.getPermutation ());
951- SmallVector<int64_t > newPerm;
952- for (auto reassoc : reassociation) {
953- for (auto dim : reassoc) {
954- newPerm.push_back (dim);
955- }
956- }
957- fusedOp = rewriter.create <TransposeOp>(
958- linalgOp.getLoc (), expandedOpOperands[0 ], outputs[0 ], newPerm);
959- })
960- // All other expandable linalg ops that are not generic or transpose can
961- // be cloned with the expanded input and output operands.
962- .Default ([&](Operation *op) {
963- fusedOp = clone (
964- rewriter, linalgOp, resultTypes,
965- llvm::to_vector (llvm::concat<Value>(expandedOpOperands, outputs)));
966- });
972+ SmallVector<ReassociationIndices> reassociationBeforeExpansion =
973+ isExpanding ? expandingReshapeOp.getReassociationIndices ()
974+ : collapsingReshapeOp.getReassociationIndices ();
975+ Operation *fusedOp = createFusedOpForReshapeByExpansion (
976+ rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
977+ expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
967978 // Reshape the result values to their original shape if this is a collapsing
968979 // reshape folded into its consumer.
969980 SmallVector<Value> resultVals;
0 commit comments