Skip to content

Commit 7782778

Browse files
authored
[DispatchCreation][DT] Only fuse encodings with element-wise (#20143)
This PR restricts the fusion of set_encoding ops to only dispatches with element-wise linalg ops. We don't have good codegen for fusions with reduction ops on GPU yet, so this will allow using the late materialization data tiling path on GPU while the codegen improves. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent a09be42 commit 7782778

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "iree/compiler/DispatchCreation/Passes.h"
1212
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1313
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
14+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1415
#include "mlir/IR/MLIRContext.h"
1516
#include "mlir/IR/PatternMatch.h"
1617
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -25,10 +26,37 @@ namespace mlir::iree_compiler::DispatchCreation {
2526

2627
namespace {
2728

28-
// Return true if the op is fusable with a SetEncodingOp consumer.
29-
// For now, just check if it is a LinalgOp.
29+
// Return true if the op is fusable with a SetEncodingOp consumer. For now,
30+
// the op's containing dispatch region must not contain any ops other than
31+
// element-wise linalg ops and some tensor ops. This is quite conservative,
32+
// and could be extended to more ops when we are confident that the codegen
33+
// backends can support it.
3034
static bool isFusableWithSetEncoding(Operation *op) {
31-
return isa<linalg::LinalgOp>(op);
35+
auto parentRegion = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
36+
// Make sure the dispatch region has only one block.
37+
if (!llvm::hasSingleElement(parentRegion.getBody())) {
38+
return false;
39+
}
40+
// Check that there are no ops other than reshapes and element-wise linalg
41+
// ops in the dispatch region.
42+
Block &regionBlock = parentRegion.getBody().getBlocks().front();
43+
for (Operation &op : regionBlock.getOperations()) {
44+
if (llvm::none_of(op.getResultTypes(), llvm::IsaPred<ShapedType>)) {
45+
continue;
46+
}
47+
if (isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::EmptyOp>(
48+
op)) {
49+
continue;
50+
}
51+
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
52+
if (!linalgOp) {
53+
return false;
54+
}
55+
if (linalgOp.getNumReductionLoops() != 0) {
56+
return false;
57+
}
58+
}
59+
return true;
3260
}
3361

3462
struct FuseEncodingOpsIntoDispatchRegionsPass

compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,15 @@ module {
6969
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
7070
// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], round_dims_to = array<i64: 32, 32, 32>>
7171
// CHECK-LABEL: @reduction_fusion
72-
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32, #[[$ENCODING]]
72+
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32>)
7373
// CHECK: %[[REDUCTION:.+]] = linalg.generic
74-
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[REDUCTION]]
74+
// CHECK: flow.return %[[REDUCTION]] :
75+
// CHECK: }
76+
// CHECK: %[[DISPATCH_SE:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32, #[[$ENCODING]]>)
77+
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH]]
7578
// CHECK: flow.return %[[SET_ENCODING]] :
7679
// CHECK: }
77-
// CHECK: util.return %[[DISPATCH]] : tensor<2x11008x128xf32, #[[$ENCODING]]>
80+
// CHECK: util.return %[[DISPATCH_SE]] : tensor<2x11008x128xf32, #[[$ENCODING]]>
7881

7982
// -----
8083

0 commit comments

Comments
 (0)