@@ -814,26 +814,64 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
814814 }
815815 return success ();
816816}
817+
818+ // Create an expanded transpose op.
819+ static Operation *
820+ createExpandedTransposeOp (PatternRewriter &rewriter, TransposeOp transposeOp,
821+ SmallVector<ReassociationIndices> reassociation,
822+ Value expandedInput, Value output) {
823+ applyPermutationToVector (reassociation, transposeOp.getPermutation ());
824+ SmallVector<int64_t > newPerm;
825+ for (auto reassoc : reassociation) {
826+ for (auto dim : reassoc) {
827+ newPerm.push_back (dim);
828+ }
829+ }
830+ return rewriter.create <TransposeOp>(transposeOp.getLoc (), expandedInput,
831+ output, newPerm);
832+ }
833+
834+ // Create an expanded generic op.
835+ static Operation *createExpandedGenericOp (
836+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
837+ ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
838+ ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
839+ // The iterator types of the expanded op are all parallel.
840+ SmallVector<utils::IteratorType> iteratorTypes (
841+ expansionInfo.getExpandedOpNumDims (), utils::IteratorType::parallel);
842+
843+ for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
844+ for (auto j : expansionInfo.getExpandedDims (i))
845+ iteratorTypes[j] = type;
846+
847+ Operation *fused = rewriter.create <GenericOp>(
848+ linalgOp.getLoc (), resultTypes, expandedOpOperands, outputs,
849+ expandedOpIndexingMaps, iteratorTypes);
850+
851+ Region &fusedRegion = fused->getRegion (0 );
852+ Region &originalRegion = linalgOp->getRegion (0 );
853+ rewriter.cloneRegionBefore (originalRegion, fusedRegion, fusedRegion.begin ());
854+
855+ // Update the index accesses after the expansion.
856+ updateExpandedGenericOpRegion (rewriter, linalgOp.getLoc (), fusedRegion,
857+ expansionInfo);
858+
859+ return fused;
860+ }
861+
817862// Create an expanded fused op that retains the name for certain ops
818863// such as fill, copy and transpose and produce a generic op for
819864// rest of linalg ops.
820- Operation *createFusedOpForReshapeByExpansion (
865+ static Operation *createExpandedOp (
821866 PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
822867 ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
823868 ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
824869 SmallVector<ReassociationIndices> reassociation) {
825870
826871 return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation ())
827- .Case <TransposeOp>([&](TransposeOp op) {
828- applyPermutationToVector (reassociation, op.getPermutation ());
829- SmallVector<int64_t > newPerm;
830- for (auto reassoc : reassociation) {
831- for (auto dim : reassoc) {
832- newPerm.push_back (dim);
833- }
834- }
835- return rewriter.create <TransposeOp>(
836- linalgOp.getLoc (), expandedOpOperands[0 ], outputs[0 ], newPerm);
872+ .Case <TransposeOp>([&](TransposeOp transposeOp) {
873+ return createExpandedTransposeOp (rewriter, transposeOp, reassociation,
874+ expandedOpOperands[0 ], outputs[0 ]);
837875 })
838876 .Case <FillOp, CopyOp>([&](Operation *op) {
839877 return clone (rewriter, linalgOp, resultTypes,
@@ -842,25 +880,9 @@ Operation *createFusedOpForReshapeByExpansion(
842880 llvm::to_vector (outputs))));
843881 })
844882 .Default ([&](Operation *op) {
845- // The iterator types of the expanded op are all parallel.
846- SmallVector<utils::IteratorType> iteratorTypes (
847- expansionInfo.getExpandedOpNumDims (),
848- utils::IteratorType::parallel);
849- for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
850- for (auto j : expansionInfo.getExpandedDims (i))
851- iteratorTypes[j] = type;
852- Operation *fused = rewriter.create <GenericOp>(
853- linalgOp.getLoc (), resultTypes, expandedOpOperands, outputs,
854- expandedOpIndexingMaps, iteratorTypes);
855- Region &fusedRegion = fused->getRegion (0 );
856- Region &originalRegion = linalgOp->getRegion (0 );
857- rewriter.cloneRegionBefore (originalRegion, fusedRegion,
858- fusedRegion.begin ());
859-
860- // Update the index accesses after the expansion.
861- updateExpandedGenericOpRegion (rewriter, linalgOp.getLoc (), fusedRegion,
862- expansionInfo);
863- return fused;
883+ return createExpandedGenericOp (rewriter, linalgOp, resultTypes,
884+ expandedOpOperands, outputs,
885+ expansionInfo, expandedOpIndexingMaps);
864886 });
865887}
866888
@@ -972,7 +994,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
972994 SmallVector<ReassociationIndices> reassociationBeforeExpansion =
973995 isExpanding ? expandingReshapeOp.getReassociationIndices ()
974996 : collapsingReshapeOp.getReassociationIndices ();
975- Operation *fusedOp = createFusedOpForReshapeByExpansion (
997+ Operation *fusedOp = createExpandedOp (
976998 rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
977999 expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
9781000 // Reshape the result values to their original shape if this is a collapsing
0 commit comments