Skip to content

Commit 0e75e3e

Browse files
Further refactor
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 3526164 commit 0e75e3e

File tree

1 file changed

+53
-31
lines changed

1 file changed

+53
-31
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -814,26 +814,64 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
814814
}
815815
return success();
816816
}
817+
818+
// Create an expanded transpose op.
819+
static Operation *
820+
createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
821+
SmallVector<ReassociationIndices> reassociation,
822+
Value expandedInput, Value output) {
823+
applyPermutationToVector(reassociation, transposeOp.getPermutation());
824+
SmallVector<int64_t> newPerm;
825+
for (auto reassoc : reassociation) {
826+
for (auto dim : reassoc) {
827+
newPerm.push_back(dim);
828+
}
829+
}
830+
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
831+
output, newPerm);
832+
}
833+
834+
// Create an expanded generic op.
835+
static Operation *createExpandedGenericOp(
836+
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
837+
ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
838+
ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
839+
// The iterator types of the expanded op are all parallel.
840+
SmallVector<utils::IteratorType> iteratorTypes(
841+
expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
842+
843+
for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
844+
for (auto j : expansionInfo.getExpandedDims(i))
845+
iteratorTypes[j] = type;
846+
847+
Operation *fused = rewriter.create<GenericOp>(
848+
linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
849+
expandedOpIndexingMaps, iteratorTypes);
850+
851+
Region &fusedRegion = fused->getRegion(0);
852+
Region &originalRegion = linalgOp->getRegion(0);
853+
rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
854+
855+
// Update the index accesses after the expansion.
856+
updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
857+
expansionInfo);
858+
859+
return fused;
860+
}
861+
817862
// Create an expanded fused op that retains the name for certain ops
818863
// such as fill, copy and transpose and produce a generic op for
819864
// rest of linalg ops.
820-
Operation *createFusedOpForReshapeByExpansion(
865+
static Operation *createExpandedOp(
821866
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
822867
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
823868
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
824869
SmallVector<ReassociationIndices> reassociation) {
825870

826871
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
827-
.Case<TransposeOp>([&](TransposeOp op) {
828-
applyPermutationToVector(reassociation, op.getPermutation());
829-
SmallVector<int64_t> newPerm;
830-
for (auto reassoc : reassociation) {
831-
for (auto dim : reassoc) {
832-
newPerm.push_back(dim);
833-
}
834-
}
835-
return rewriter.create<TransposeOp>(
836-
linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
872+
.Case<TransposeOp>([&](TransposeOp transposeOp) {
873+
return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
874+
expandedOpOperands[0], outputs[0]);
837875
})
838876
.Case<FillOp, CopyOp>([&](Operation *op) {
839877
return clone(rewriter, linalgOp, resultTypes,
@@ -842,25 +880,9 @@ Operation *createFusedOpForReshapeByExpansion(
842880
llvm::to_vector(outputs))));
843881
})
844882
.Default([&](Operation *op) {
845-
// The iterator types of the expanded op are all parallel.
846-
SmallVector<utils::IteratorType> iteratorTypes(
847-
expansionInfo.getExpandedOpNumDims(),
848-
utils::IteratorType::parallel);
849-
for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
850-
for (auto j : expansionInfo.getExpandedDims(i))
851-
iteratorTypes[j] = type;
852-
Operation *fused = rewriter.create<GenericOp>(
853-
linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
854-
expandedOpIndexingMaps, iteratorTypes);
855-
Region &fusedRegion = fused->getRegion(0);
856-
Region &originalRegion = linalgOp->getRegion(0);
857-
rewriter.cloneRegionBefore(originalRegion, fusedRegion,
858-
fusedRegion.begin());
859-
860-
// Update the index accesses after the expansion.
861-
updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
862-
expansionInfo);
863-
return fused;
883+
return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
884+
expandedOpOperands, outputs,
885+
expansionInfo, expandedOpIndexingMaps);
864886
});
865887
}
866888

@@ -972,7 +994,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
972994
SmallVector<ReassociationIndices> reassociationBeforeExpansion =
973995
isExpanding ? expandingReshapeOp.getReassociationIndices()
974996
: collapsingReshapeOp.getReassociationIndices();
975-
Operation *fusedOp = createFusedOpForReshapeByExpansion(
997+
Operation *fusedOp = createExpandedOp(
976998
rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
977999
expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
9781000
// Reshape the result values to their original shape if this is a collapsing

0 commit comments

Comments
 (0)