Skip to content

Commit 963e2e9

Browse files
authored
[CodeGen] Do not fuse parallel ops if they directly write to destination. (iree-org#21837)
The pass was mainly introduced for the softmax dispatch, so it is okay to limit the scope of the fusion. If we unconditionally fuse the ops, it may result in independent compute ops. In that case there is more than one root op in the dispatch, which codegen does not expect. It is basically the result of two dispatches being formed into a single dispatch. Fixes iree-org#21836. It is a step towards iree-org#21828; there are other issues about dominance in some dispatches. --------- Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 1c54e4d commit 963e2e9

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "iree/compiler/Codegen/Common/Passes.h"
88
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
9+
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
10+
#include "llvm/Support/Casting.h"
911
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1012
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1113

@@ -25,6 +27,11 @@ static bool isScalarOrTensorOfSizeOne(Type t) {
2527
return t.isIntOrIndexOrFloat();
2628
}
2729

30+
/// Returns true if any user of `op` is an
/// `iree_tensor_ext.dispatch.tensor.store` op, i.e., the op's result is
/// written directly to a dispatch output binding. Such producers must not be
/// fused away, because their result is independently required by the store.
static bool hasDirectWriteResult(Operation *op) {
  return llvm::any_of(op->getUsers(),
                      llvm::IsaPred<IREE::TensorExt::DispatchTensorStoreOp>);
}
34+
2835
/// Rematerialize all parallel elementwise operations into its users within a
2936
/// `flow.dispatch.region`.
3037
struct RematerializeParallelOpsPattern
@@ -51,6 +58,9 @@ struct RematerializeParallelOpsPattern
5158
if (producer && hasExternalCapture(producer)) {
5259
continue;
5360
}
61+
if (producer && hasDirectWriteResult(producer)) {
62+
continue;
63+
}
5464
FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
5565
linalg::fuseElementwiseOps(rewriter, &opOperand);
5666
if (succeeded(fusionResult)) {

compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,46 @@ func.func @no_external_capture_fusion(%arg0: tensor<4096x64xi64>, %arg1: tensor<
169169
// CHECK-LABEL: func @no_external_capture_fusion(
170170
// CHECK: linalg.generic
171171
// CHECK: linalg.generic
172+
173+
// -----
174+
175+
// Checks that an elementwise producer is NOT fused into its reduction
// consumer when the producer's result is also stored directly to a dispatch
// output: fusing would leave two independent root ops in one dispatch.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#pipeline_layout = #hal.pipeline.layout<
  bindings = [
    #hal.pipeline.binding<storage_buffer, Indirect>,
    #hal.pipeline.binding<storage_buffer, Indirect>
  ]>
func.func @producer_has_direct_write(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %c64 = arith.constant 64 : index
  %c128 = arith.constant 128 : index
  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c64) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x5xf32>>
  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c128) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x4x5xf32>>
  %2 = tensor.empty() : tensor<3x5xf32>
  %3 = tensor.empty() : tensor<3x4x5xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<3x5xf32>) -> tensor<3x5xf32>
  %5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<3x4x5xf32>, tensor<3x5xf32>) outs(%3 : tensor<3x4x5xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %7 = arith.subf %in, %in_0 : f32
    linalg.yield %7 : f32
  } -> tensor<3x4x5xf32>
  %6 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%5 : tensor<3x4x5xf32>) outs(%4 : tensor<3x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %7 = math.exp %in : f32
    %8 = arith.addf %7, %out : f32
    linalg.yield %8 : f32
  } -> tensor<3x5xf32>
  iree_tensor_ext.dispatch.tensor.store %6, %0, offsets = [0, 0], sizes = [3, 5], strides = [1, 1] : tensor<3x5xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x5xf32>>
  iree_tensor_ext.dispatch.tensor.store %5, %1, offsets = [0, 0, 0], sizes = [3, 4, 5], strides = [1, 1, 1] : tensor<3x4x5xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x4x5xf32>>
  return
}
// CHECK-LABEL: func.func @producer_has_direct_write
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK:         %[[REDUCTION:.+]] = linalg.generic
// CHECK-SAME:      ins(%[[ELEM]]
// CHECK-DAG:     iree_tensor_ext.dispatch.tensor.store %[[REDUCTION]]
// CHECK-DAG:     iree_tensor_ext.dispatch.tensor.store %[[ELEM]]

0 commit comments

Comments (0)