
Commit 3526164

Address reviewer comments
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 620f2a8 commit 3526164

2 files changed: +96 −48 lines

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 55 additions & 44 deletions
@@ -814,6 +814,55 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
   }
   return success();
 }
+// Create an expanded fused op that retains the name for certain ops, such as
+// fill, copy and transpose, and produce a generic op for the rest of the
+// linalg ops.
+Operation *createFusedOpForReshapeByExpansion(
+    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
+    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
+    ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
+    SmallVector<ReassociationIndices> reassociation) {
+
+  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
+      .Case<TransposeOp>([&](TransposeOp op) {
+        applyPermutationToVector(reassociation, op.getPermutation());
+        SmallVector<int64_t> newPerm;
+        for (auto reassoc : reassociation) {
+          for (auto dim : reassoc) {
+            newPerm.push_back(dim);
+          }
+        }
+        return rewriter.create<TransposeOp>(
+            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
+      })
+      .Case<FillOp, CopyOp>([&](Operation *op) {
+        return clone(rewriter, linalgOp, resultTypes,
+                     llvm::to_vector(llvm::concat<Value>(
+                         llvm::to_vector(expandedOpOperands),
+                         llvm::to_vector(outputs))));
+      })
+      .Default([&](Operation *op) {
+        // The iterator types of the expanded op are all parallel.
+        SmallVector<utils::IteratorType> iteratorTypes(
+            expansionInfo.getExpandedOpNumDims(),
+            utils::IteratorType::parallel);
+        for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
+          for (auto j : expansionInfo.getExpandedDims(i))
+            iteratorTypes[j] = type;
+        Operation *fused = rewriter.create<GenericOp>(
+            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+            expandedOpIndexingMaps, iteratorTypes);
+        Region &fusedRegion = fused->getRegion(0);
+        Region &originalRegion = linalgOp->getRegion(0);
+        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+                                   fusedRegion.begin());
+
+        // Update the index accesses after the expansion.
+        updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
+                                      expansionInfo);
+        return fused;
+      });
+}
 
 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
@@ -919,51 +968,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
     }
   }
 
-  // The iterator types of the expanded op are all parallel.
-  SmallVector<utils::IteratorType> iteratorTypes(
-      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
-  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
-    for (auto j : expansionInfo.getExpandedDims(i))
-      iteratorTypes[j] = type;
-
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  Operation *fusedOp;
-
-  TypeSwitch<Operation *>(linalgOp.getOperation())
-      .Case<GenericOp>([&](GenericOp op) {
-        fusedOp = rewriter.create<GenericOp>(
-            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
-            expandedOpIndexingMaps, iteratorTypes);
-        Region &fusedRegion = fusedOp->getRegion(0);
-        Region &originalRegion = linalgOp->getRegion(0);
-        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
-                                   fusedRegion.begin());
-
-        // Update the index accesses after the expansion.
-        updateExpandedGenericOpRegion(rewriter, loc, fusedRegion,
-                                      expansionInfo);
-      })
-      .Case<TransposeOp>([&](TransposeOp op) {
-        SmallVector<ReassociationIndices> reassociation =
-            isExpanding ? expandingReshapeOp.getReassociationIndices()
-                        : collapsingReshapeOp.getReassociationIndices();
-        applyPermutationToVector(reassociation, op.getPermutation());
-        SmallVector<int64_t> newPerm;
-        for (auto reassoc : reassociation) {
-          for (auto dim : reassoc) {
-            newPerm.push_back(dim);
-          }
-        }
-        fusedOp = rewriter.create<TransposeOp>(
-            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
-      })
-      // All other expandable linalg ops that are not generic or transpose can
-      // be cloned with the expanded input and output operands.
-      .Default([&](Operation *op) {
-        fusedOp = clone(
-            rewriter, linalgOp, resultTypes,
-            llvm::to_vector(llvm::concat<Value>(expandedOpOperands, outputs)));
-      });
+  SmallVector<ReassociationIndices> reassociationBeforeExpansion =
+      isExpanding ? expandingReshapeOp.getReassociationIndices()
+                  : collapsingReshapeOp.getReassociationIndices();
+  Operation *fusedOp = createFusedOpForReshapeByExpansion(
+      rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
+      expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
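
To make the TransposeOp branch above easier to follow, here is a minimal, standalone C++ sketch of the permutation remapping it performs: the reshape's reassociation groups are reordered by the transpose's original permutation and then flattened into a permutation over the expanded dims. expandPermutation and the main driver are hypothetical illustrations written for this note, not part of the patch or of MLIR's API.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical mirror of the TransposeOp case: reorder the reassociation
// groups by the original permutation, then flatten the expanded-dim indices
// to obtain the permutation of the expanded transpose.
std::vector<int64_t>
expandPermutation(const std::vector<std::vector<int64_t>> &reassociation,
                  const std::vector<int64_t> &perm) {
  std::vector<int64_t> newPerm;
  for (int64_t origDim : perm) {
    // Each original dim maps to a contiguous group of expanded dims; the
    // group keeps its internal order, only the groups are permuted.
    const std::vector<int64_t> &group = reassociation[origDim];
    newPerm.insert(newPerm.end(), group.begin(), group.end());
  }
  return newPerm;
}

int main() {
  // Reassociation [[0, 1], [2, 3]] with original permutation (d0, d1) ->
  // (d1, d0) yields the expanded permutation [2, 3, 0, 1].
  for (int64_t d : expandPermutation({{0, 1}, {2, 3}}, {1, 0}))
    std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}

This matches what the patch does with applyPermutationToVector followed by the flattening loop, and corresponds to the linalg_transpose_reshape_producer_fusion test below.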

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 41 additions & 4 deletions
@@ -753,6 +753,7 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
  return %1 : tensor<?x?x4x5xf32>
}

+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @linalg_add_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -773,7 +774,9 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[T4:.+]] = linalg.add
+// CHECK: %[[T4:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
// CHECK-SAME: outs(%[[T3]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T4]] : tensor<?x?x4x5xf32>
@@ -792,6 +795,7 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
  return %1 : tensor<?x?xf32>
}

+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @linalg_add_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -810,7 +814,9 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-// CHECK: %[[T3:.+]] = linalg.add
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x7x?x8xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
@@ -820,6 +826,39 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,

// -----

+func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                               %arg1 : tensor<?x?xf32>) ->
+                                               tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.copy ins(%0 : tensor<?x?xf32>)
+                   outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_copy_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[C8:.+]] = arith.constant 8 : index
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
+// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK: %[[T2:.+]] = linalg.copy
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x7x?x8xf32>)
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+// -----
+
func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                                    %arg1 : tensor<?x?xf32>) ->
                                                    tensor<?x?xf32>
@@ -852,8 +891,6 @@ func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
// CHECK: return %[[T3]]

-
-
// -----

func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
