Commit e43555e
[CodeGen] Fix gather fusion on vector distribute path (iree-org#21117)
Don't attach a lowering config to the gather operation; let it fuse with its consumer operation instead. This eliminates the creation of a temporary buffer. The change also restricts element-wise fusion during softmax decomposition when the producer is a gather-like operation. Fixes iree-org#21107.

1 parent: a073601

7 files changed: +116 −42 lines

compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp

Lines changed: 7 additions & 1 deletion
@@ -45,7 +45,13 @@ struct FuseElementWiseGenericOps : public OpRewritePattern<linalg::GenericOp> {
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
       if (!linalg::areElementwiseOpsFusable(&opOperand))
         continue;
-
+      // Don't fuse if the producer has an external capture. For example, a
+      // gather-like payload operation such as 'tensor.extract' would be
+      // cloned into every consumer op, which is not what we want.
+      auto producer = opOperand.get().getDefiningOp<linalg::GenericOp>();
+      if (producer && hasExternalCapture(producer)) {
+        continue;
+      }
       FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
           linalg::fuseElementwiseOps(rewriter, &opOperand);
       if (succeeded(fusionResult)) {
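For context, a rewrite pattern like `FuseElementWiseGenericOps` is typically applied through MLIR's greedy pattern driver. A minimal sketch of that plumbing, assuming a standalone driver function (`runFusion` and its signature are illustrative, not part of this commit):

```cpp
// Illustrative driver (assumed, not from this commit): register the fusion
// pattern and apply it greedily until no more fusions fire.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void runFusion(mlir::func::FuncOp funcOp) {
  mlir::MLIRContext *ctx = funcOp.getContext();
  mlir::RewritePatternSet patterns(ctx);
  patterns.add<FuseElementWiseGenericOps>(ctx);
  // Named applyPatternsAndFoldGreedily in older MLIR revisions.
  if (mlir::failed(mlir::applyPatternsGreedily(funcOp, std::move(patterns)))) {
    // Did not converge; a pass would normally signalPassFailure() here.
  }
}
```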

compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp

Lines changed: 0 additions & 41 deletions
@@ -25,47 +25,6 @@ static bool isScalarOrTensorOfSizeOne(Type t) {
   return t.isIntOrIndexOrFloat();
 }
 
-/// This function checks whether the `genericOp` has any external captures,
-/// i.e., whether it uses any values that are defined outside of its body.
-/// %10 = linalg.generic {indexing_maps = [#map, #map],
-///                       iterator_types = ["parallel", "parallel"]}
-///     ins(%5 : tensor<4096x64xi64>) outs(%9 : tensor<4096x64xf16>) {
-///   ^bb0(%in: i64, %out: f16):
-///     %14 = linalg.index 0 : index
-///     %15 = arith.index_cast %in : i64 to index
-///     %extracted = tensor.extract %4[%14, %15] : tensor<4096x64xf16>
-///     linalg.yield %extracted : f16
-/// } -> tensor<4096x64xf16>
-/// Here %4 is an external capture used via tensor.extract inside
-/// linalg.generic hence the above `genericOp` has an external capture.
-static bool hasExternalCapture(linalg::GenericOp genericOp) {
-  Block &body = genericOp.getRegion().front();
-  for (Operation &op : body.getOperations()) {
-    for (Value operand : op.getOperands()) {
-      if (auto bArg = dyn_cast<BlockArgument>(operand)) {
-        // Check whether the operand lies in the same block.
-        if (bArg.getOwner() == &body) {
-          continue;
-        }
-        return true;
-      }
-      Operation *defOp = operand.getDefiningOp();
-      // Scalar constant is allowed.
-      if (defOp && defOp->hasTrait<mlir::OpTrait::ConstantLike>()) {
-        Type type = operand.getType();
-        if (type.isIntOrFloat() || type.isIndex()) {
-          continue;
-        }
-      }
-      // If the defining op is not inside the block, it's an external value.
-      if (!defOp || defOp->getBlock() != &body) {
-        return true;
-      }
-    }
-  }
-  return false; // All operands are locally defined or block arguments.
-}
-
 /// Rematerialize all parallel elementwise operations into its users within a
 /// `flow.dispatch.region`.
 struct RematerializeParallelOpsPattern

compiler/src/iree/compiler/Codegen/Common/test/decompose_softmax.mlir

Lines changed: 21 additions & 0 deletions
@@ -82,3 +82,24 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-NO-FUSE: } -> tensor<2x16x32xf32>
 // CHECK-NO-FUSE: return %[[D7]] : tensor<2x16x32xf32>
 // CHECK-NO-FUSE: }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @do_not_fuse_gather(%arg0: tensor<4096x64xi64>, %arg1: tensor<4096x64xf32>) -> tensor<4096x64xf32> {
+  %empty = tensor.empty() : tensor<4096x64xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4096x64xi64>) outs(%empty : tensor<4096x64xf32>) {
+  ^bb0(%in: i64, %out: f32):
+    %3 = linalg.index 0 : index
+    %4 = arith.index_cast %in : i64 to index
+    %extracted = tensor.extract %arg1[%3, %4] : tensor<4096x64xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<4096x64xf32>
+  %s_empty = tensor.empty() : tensor<4096x64xf32>
+  %1 = linalg.softmax dimension(1) ins(%0 : tensor<4096x64xf32>) outs(%s_empty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
+  return %1 : tensor<4096x64xf32>
+}
+// CHECK-LABEL: func @do_not_fuse_gather(
+// CHECK:       linalg.generic {{.*}}
+// CHECK:       tensor.extract {{.*}} : tensor<4096x64xf32>
+// CHECK-COUNT-3: linalg.generic

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 5 additions & 0 deletions
@@ -544,6 +544,11 @@ populateConfigInfo(const llvm::SetVector<linalg::LinalgOp> &computeOps,
   // LinalgOp with only parallel dims. This is needed if the op cannot be fused
   // with a reduction or introduces new loop dimensions.
   auto shouldAttachLoweringConfig = [&](linalg::LinalgOp linalgOp) -> bool {
+    // If the operation has a gather, we want to fuse it with the
+    // reduction.
+    if (hasExternalCapture(cast<linalg::GenericOp>(linalgOp))) {
+      return false;
+    }
     // If some of the users are in computeOps and some are outside of
     // computeOps; attach lowering config, since the op can't be fused.
     if (llvm::any_of(linalgOp->getUsers(),
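One caveat worth noting: `hasExternalCapture` takes a `linalg::GenericOp`, while the lambda receives a `linalg::LinalgOp`, so the `cast` above asserts if a non-generic LinalgOp ever reaches this path. A hedged sketch of a defensive variant (`isGatherLike` is a hypothetical helper, not in this commit):

```cpp
// Hypothetical helper (not part of the commit): only linalg.generic ops can
// capture values in their body region, so any other LinalgOp reports false
// instead of tripping the cast assertion.
static bool isGatherLike(mlir::linalg::LinalgOp linalgOp) {
  auto genericOp =
      llvm::dyn_cast<mlir::linalg::GenericOp>(linalgOp.getOperation());
  return genericOp && hasExternalCapture(genericOp);
}
```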

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir

Lines changed: 40 additions & 0 deletions
@@ -251,3 +251,43 @@ func.func @test_multiple_stores(%arg0: !iree_tensor_ext.dispatch.tensor<readonly
 // CHECK-SAME:   subgroup_basis = {{\[}}[1, 16], [0, 1]],
 // CHECK-SAME:   thread = [0, 4], thread_basis = {{\[}}[1, 64], [0, 1]],
 // CHECK-SAME:   workgroup = [1, 0]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+// Test that no lowering config is attached to the gather-like operation.
+func.func @test_gather_config(%arg0: !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096xi64>>, %arg1: !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x64xf32>>, %arg2: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4096xf32>>) {
+  %c2_i64 = arith.constant 2 : i64
+  %c0 = arith.constant 0 : index
+  %load1 = iree_tensor_ext.dispatch.tensor.load %arg0, offsets = [0], sizes = [4096], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096xi64>> -> tensor<4096xi64>
+  %load2 = iree_tensor_ext.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [4096, 64], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x64xf32>> -> tensor<4096x64xf32>
+  %0 = tensor.empty() : tensor<4096x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel"]} ins(%load1 : tensor<4096xi64>) outs(%0 : tensor<4096x64xf32>) {
+  ^bb0(%in: i64, %out: f32):
+    %4 = linalg.index 0 : index
+    %5 = linalg.index 1 : index
+    %extracted = tensor.extract %load2[%4, %5] : tensor<4096x64xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<4096x64xf32>
+  %2 = tensor.empty() : tensor<4096xf32>
+  %3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%1 : tensor<4096x64xf32>) outs(%2 : tensor<4096xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %4 = arith.addf %in, %out : f32
+    linalg.yield %4 : f32
+  } -> tensor<4096xf32>
+  iree_tensor_ext.dispatch.tensor.store %3, %arg2, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4096xf32>>
+  return
+}
+// CHECK:      #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64
+// CHECK:      func.func @test_gather_config
+// CHECK-SAME:   translation_info = #[[$TRANSLATION]]
+// CHECK:      linalg.generic
+// CHECK-NOT:    attrs = {lowering_config = #iree_gpu.lowering_config<{
+// CHECK:      linalg.yield
+// CHECK:      linalg.generic
+// CHECK-SAME:   attrs = {lowering_config = #iree_gpu.lowering_config<{
+// CHECK-SAME:     partial_reduction = [0, 64],
+// CHECK-SAME:     subgroup_basis = {{\[}}[1, 1], [0, 1]],
+// CHECK-SAME:     thread = [0, 1], thread_basis = {{\[}}[1, 64], [0, 1]],
+// CHECK-SAME:     workgroup = [1, 0]

compiler/src/iree/compiler/Codegen/Utils/Utils.cpp

Lines changed: 28 additions & 0 deletions
@@ -1941,4 +1941,32 @@ bool neverRunsSecondIteration(scf::ForOp op) {
   return isUbUnderStep.value_or(false) && isLbNonNegative.value_or(false);
 }
 
+bool hasExternalCapture(linalg::GenericOp genericOp) {
+  Block &body = genericOp.getRegion().front();
+  for (Operation &op : body.getOperations()) {
+    for (Value operand : op.getOperands()) {
+      if (auto bArg = dyn_cast<BlockArgument>(operand)) {
+        // Check whether the operand lies in the same block.
+        if (bArg.getOwner() == &body) {
+          continue;
+        }
+        return true;
+      }
+      Operation *defOp = operand.getDefiningOp();
+      // A scalar constant is allowed.
+      if (defOp && defOp->hasTrait<mlir::OpTrait::ConstantLike>()) {
+        Type type = operand.getType();
+        if (type.isIntOrFloat() || type.isIndex()) {
+          continue;
+        }
+      }
+      // If the defining op is not inside the block, it's an external value.
+      if (!defOp || defOp->getBlock() != &body) {
+        return true;
+      }
+    }
+  }
+  return false; // All operands are locally defined or block arguments.
+}
+
 } // namespace mlir::iree_compiler
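With the helper now in shared utils, a caller can flag gather-like generics with a plain walk. A hedged usage sketch (`collectGatherLikeOps` is an assumed name for illustration, not part of this commit):

```cpp
// Assumed usage sketch (not from the commit): collect the linalg.generic ops
// in a function whose bodies capture external values, e.g. to skip them when
// attaching lowering configs.
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

static llvm::SmallVector<mlir::linalg::GenericOp>
collectGatherLikeOps(mlir::FunctionOpInterface funcOp) {
  llvm::SmallVector<mlir::linalg::GenericOp> gatherLike;
  funcOp.walk([&](mlir::linalg::GenericOp genericOp) {
    if (mlir::iree_compiler::hasExternalCapture(genericOp))
      gatherLike.push_back(genericOp);
  });
  return gatherLike;
}
```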

compiler/src/iree/compiler/Codegen/Utils/Utils.h

Lines changed: 15 additions & 0 deletions
@@ -331,6 +331,21 @@ bool alwaysRunsFirstIteration(scf::ForOp op);
 /// the ForOp.
 bool neverRunsSecondIteration(scf::ForOp op);
 
+/// This function checks whether the `genericOp` has any external captures,
+/// i.e., whether it uses any values that are defined outside of its body.
+/// %10 = linalg.generic {indexing_maps = [#map, #map],
+///                       iterator_types = ["parallel", "parallel"]}
+///     ins(%5 : tensor<4096x64xi64>) outs(%9 : tensor<4096x64xf16>) {
+///   ^bb0(%in: i64, %out: f16):
+///     %14 = linalg.index 0 : index
+///     %15 = arith.index_cast %in : i64 to index
+///     %extracted = tensor.extract %4[%14, %15] : tensor<4096x64xf16>
+///     linalg.yield %extracted : f16
+/// } -> tensor<4096x64xf16>
+/// Here %4 is an external capture used via tensor.extract inside the
+/// linalg.generic, hence the above `genericOp` has an external capture.
+bool hasExternalCapture(linalg::GenericOp genericOp);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_CODEGEN_UTILS_UTILS_H_
