@@ -815,6 +815,77 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
815815 return success ();
816816}
817817
818+ // Create an expanded transpose op.
819+ static Operation *
820+ createExpandedTransposeOp (PatternRewriter &rewriter, TransposeOp transposeOp,
821+ SmallVector<ReassociationIndices> reassociation,
822+ Value expandedInput, Value output) {
823+ applyPermutationToVector (reassociation, transposeOp.getPermutation ());
824+ SmallVector<int64_t > newPerm;
825+ for (auto reassoc : reassociation) {
826+ for (auto dim : reassoc) {
827+ newPerm.push_back (dim);
828+ }
829+ }
830+ return rewriter.create <TransposeOp>(transposeOp.getLoc (), expandedInput,
831+ output, newPerm);
832+ }
833+
834+ // Create an expanded generic op.
835+ static Operation *createExpandedGenericOp (
836+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
837+ ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
838+ ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
839+ // The iterator types of the expanded op are all parallel.
840+ SmallVector<utils::IteratorType> iteratorTypes (
841+ expansionInfo.getExpandedOpNumDims (), utils::IteratorType::parallel);
842+
843+ for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
844+ for (auto j : expansionInfo.getExpandedDims (i))
845+ iteratorTypes[j] = type;
846+
847+ Operation *fused = rewriter.create <GenericOp>(
848+ linalgOp.getLoc (), resultTypes, expandedOpOperands, outputs,
849+ expandedOpIndexingMaps, iteratorTypes);
850+
851+ Region &fusedRegion = fused->getRegion (0 );
852+ Region &originalRegion = linalgOp->getRegion (0 );
853+ rewriter.cloneRegionBefore (originalRegion, fusedRegion, fusedRegion.begin ());
854+
855+ // Update the index accesses after the expansion.
856+ updateExpandedGenericOpRegion (rewriter, linalgOp.getLoc (), fusedRegion,
857+ expansionInfo);
858+
859+ return fused;
860+ }
861+
862+ // Create an expanded fused op that retains the name for certain ops
863+ // such as fill, copy and transpose and produce a generic op for
864+ // rest of linalg ops.
865+ static Operation *createExpandedOp (
866+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
867+ ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
868+ ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
869+ SmallVector<ReassociationIndices> reassociation) {
870+
871+ return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation ())
872+ .Case <TransposeOp>([&](TransposeOp transposeOp) {
873+ return createExpandedTransposeOp (rewriter, transposeOp, reassociation,
874+ expandedOpOperands[0 ], outputs[0 ]);
875+ })
876+ .Case <FillOp, CopyOp>([&](Operation *op) {
877+ return clone (rewriter, linalgOp, resultTypes,
878+ llvm::to_vector (llvm::concat<Value>(
879+ llvm::to_vector (expandedOpOperands),
880+ llvm::to_vector (outputs))));
881+ })
882+ .Default ([&](Operation *op) {
883+ return createExpandedGenericOp (rewriter, linalgOp, resultTypes,
884+ expandedOpOperands, outputs,
885+ expansionInfo, expandedOpIndexingMaps);
886+ });
887+ }
888+
818889// / Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
819890// / and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
820891// / that those conditions have been satisfied.
@@ -919,25 +990,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
919990 }
920991 }
921992
922- // The iterator types of the expanded op are all parallel.
923- SmallVector<utils::IteratorType> iteratorTypes (
924- expansionInfo.getExpandedOpNumDims (), utils::IteratorType::parallel);
925- for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
926- for (auto j : expansionInfo.getExpandedDims (i))
927- iteratorTypes[j] = type;
928-
929993 TypeRange resultTypes = ValueRange (outputs).getTypes ();
930- auto fusedOp =
931- rewriter.create <GenericOp>(linalgOp.getLoc (), resultTypes,
932- /* inputs=*/ expandedOpOperands, outputs,
933- expandedOpIndexingMaps, iteratorTypes);
934- Region &fusedRegion = fusedOp->getRegion (0 );
935- Region &originalRegion = linalgOp->getRegion (0 );
936- rewriter.cloneRegionBefore (originalRegion, fusedRegion, fusedRegion.begin ());
937-
938- // Update the index accesses after the expansion.
939- updateExpandedGenericOpRegion (rewriter, loc, fusedRegion, expansionInfo);
940-
994+ SmallVector<ReassociationIndices> reassociationBeforeExpansion =
995+ isExpanding ? expandingReshapeOp.getReassociationIndices ()
996+ : collapsingReshapeOp.getReassociationIndices ();
997+ Operation *fusedOp = createExpandedOp (
998+ rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
999+ expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
9411000 // Reshape the result values to their original shape if this is a collapsing
9421001 // reshape folded into its consumer.
9431002 SmallVector<Value> resultVals;
0 commit comments