Skip to content

Commit 75d238a

Browse files
Use expansioninfo to get output reassociation
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent d8c467e commit 75d238a

File tree

1 file changed

+18
-39
lines changed

1 file changed

+18
-39
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -811,24 +811,12 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
811811
}
812812

813813
// Create an expanded transpose op.
814-
// For sinking a collapse : transpose(collapse_shape),
815-
// all expanded groups are permuted together. We just permute the reassocation
816-
// map of the collapse and flatten it. For example,
817-
//
818-
// reassociation_map = [[0], [1, 2, 3], [4, 5]]
819-
// permutation = [2, 0, 1]
820-
//
821-
// Becomes
822-
//
823-
// permutation = [4, 5, 0 , 1, 2, 3]
824-
//
825-
// For bubbling an expand : expand_shape(transpose),
826814
// the reassociation map is already permuted hence we inverse permute and then
827815
// flatten it. Then we inverse permute it again to get the final expanded
828816
// transpose permutation. For example,
829817
//
830818
// permutation = [2, 0, 1]
831-
// reassociation_map = [[0, 1], [2], [3, 4, 5]]
819+
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
832820
//
833821
// inverse permutation = [1, 2, 0]
834822
// applied to reassocation_map and then flattened becomes
@@ -839,25 +827,19 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
839827
//
840828
// permutation=[4, 5, 0, 1, 2, 3]
841829

842-
static Operation *
843-
createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
844-
SmallVector<ReassociationIndices> reassociation,
845-
Value expandedInput, Value output, bool isExpanding) {
846-
ArrayRef<int64_t> permutation =
847-
isExpanding ? invertPermutationVector(transposeOp.getPermutation())
848-
: transposeOp.getPermutation();
849-
applyPermutationToVector(reassociation, permutation);
830+
static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
831+
TransposeOp transposeOp,
832+
Value expandedInput, Value output,
833+
ExpansionInfo &expansionInfo) {
850834
SmallVector<int64_t> newPerm;
851-
for (const auto &reassoc : reassociation) {
852-
for (auto dim : reassoc) {
835+
for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
836+
auto reassoc = expansionInfo.getExpandedDims(perm);
837+
for (int64_t dim : reassoc) {
853838
newPerm.push_back(dim);
854839
}
855840
}
856-
if (isExpanding) {
857-
newPerm = invertPermutationVector(newPerm);
858-
}
859841
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
860-
output, newPerm);
842+
output, invertPermutationVector(newPerm));
861843
}
862844

863845
// Create an expanded generic op.
@@ -891,17 +873,18 @@ static Operation *createExpandedGenericOp(
891873
// Create an expanded fused op that retains the name for certain ops
892874
// such as fill, copy and transpose and produce a generic op for
893875
// rest of linalg ops.
894-
static Operation *createExpandedOp(
895-
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
896-
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
897-
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
898-
SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
876+
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
877+
TypeRange resultTypes,
878+
ArrayRef<Value> expandedOpOperands,
879+
ArrayRef<Value> outputs,
880+
ArrayRef<AffineMap> expandedOpIndexingMaps,
881+
ExpansionInfo &expansionInfo) {
899882

900883
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
901884
.Case<TransposeOp>([&](TransposeOp transposeOp) {
902-
return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
885+
return createExpandedTransposeOp(rewriter, transposeOp,
903886
expandedOpOperands[0], outputs[0],
904-
isExpanding);
887+
expansionInfo);
905888
})
906889
.Case<FillOp, CopyOp>([&](Operation *op) {
907890
return clone(rewriter, linalgOp, resultTypes,
@@ -1021,13 +1004,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
10211004
}
10221005

10231006
TypeRange resultTypes = ValueRange(outputs).getTypes();
1024-
SmallVector<ReassociationIndices> reassociationBeforeExpansion =
1025-
isExpanding ? expandingReshapeOp.getReassociationIndices()
1026-
: collapsingReshapeOp.getReassociationIndices();
10271007
Operation *fusedOp =
10281008
createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
1029-
outputs, expandedOpIndexingMaps, expansionInfo,
1030-
reassociationBeforeExpansion, isExpanding);
1009+
outputs, expandedOpIndexingMaps, expansionInfo);
10311010
// Reshape the result values to their original shape if this is a collapsing
10321011
// reshape folded into its consumer.
10331012
SmallVector<Value> resultVals;

0 commit comments

Comments
 (0)