Skip to content

Commit cee9306

Browse files
[DispatchCreation] Add pass to move non-fused encodings into dispatches. (#20071)
From the pass that fuses encodings into producers, removes the fallback of moving encodings into their own dispatch in favor of having a separate pass that does this all at once. Signed-off-by: MaheshRavishankar <[email protected]> --------- Signed-off-by: MaheshRavishankar <[email protected]> Signed-off-by: MaheshRavishankar <[email protected]>
1 parent fd157f0 commit cee9306

File tree

10 files changed

+111
-17
lines changed

10 files changed

+111
-17
lines changed

compiler/src/iree/compiler/DispatchCreation/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ iree_compiler_cc_library(
3939
"SplitReduction.cpp",
4040
"TensorPadToTensorInsertSlice.cpp",
4141
"TransposeGenericOps.cpp",
42+
"WrapEncodingOpInDispatchRegion.cpp",
4243
],
4344
hdrs = [
4445
"FusionUtils.h",

compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ iree_cc_library(
4141
"SplitReduction.cpp"
4242
"TensorPadToTensorInsertSlice.cpp"
4343
"TransposeGenericOps.cpp"
44+
"WrapEncodingOpInDispatchRegion.cpp"
4445
DEPS
4546
::PassHeaders
4647
::PassesIncGen

compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ struct FuseEncodingOpsIntoDispatchRegionsPass
6969

7070
SmallVector<IREE::Encoding::SetEncodingOp> encodingOps;
7171
funcOp->walk([&](IREE::Encoding::SetEncodingOp encodingOp) {
72-
encodingOps.push_back(encodingOp);
72+
if (IREE::Flow::isNonNullAndOutsideDispatch(encodingOp)) {
73+
encodingOps.push_back(encodingOp);
74+
}
7375
});
7476

7577
for (IREE::Encoding::SetEncodingOp encodingOp : encodingOps) {
@@ -78,9 +80,6 @@ struct FuseEncodingOpsIntoDispatchRegionsPass
7880
operand.get().getDefiningOp<IREE::Flow::DispatchRegionOp>();
7981
// Nothing to fuse with, so wrap the `encodingOp` in its own dispatch.
8082
if (!producerDispatch) {
81-
if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
82-
return signalPassFailure();
83-
}
8483
continue;
8584
}
8685

@@ -92,17 +91,11 @@ struct FuseEncodingOpsIntoDispatchRegionsPass
9291
auto producerInRegion = dyn_cast<OpResult>(
9392
dispatchReturnOp->getOperand(result.getResultNumber()));
9493
if (!producerInRegion) {
95-
if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
96-
return signalPassFailure();
97-
}
9894
continue;
9995
}
10096

10197
// Place the op in its own dispatch region if fusion is not possible.
10298
if (!isFusableWithSetEncoding(producerInRegion.getOwner())) {
103-
if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
104-
return signalPassFailure();
105-
}
10699
continue;
107100
}
108101
// Fuse the `encodingOp` into the producer dispatch region.

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
259259
.addPass(DispatchCreation::createHoistEncodingOpsPass)
260260
// After SetEncodingOps are hoisted, try to fuse them with their
261261
// producer dispatches to try to hide packing costs.
262-
.addPass(
263-
DispatchCreation::createFuseEncodingOpsIntoDispatchRegionsPass);
262+
.addPass(DispatchCreation::createFuseEncodingOpsIntoDispatchRegionsPass)
263+
.addPass(DispatchCreation::createWrapEncodingOpInDispatchRegionPass);
264264
}
265265
}
266266

compiler/src/iree/compiler/DispatchCreation/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,20 @@ def SetEncodingPass :
290290
];
291291
}
292292

293+
def WrapEncodingOpInDispatchRegionPass :
294+
InterfacePass<"iree-dispatch-creation-wrap-encoding-op-in-dispatch-region",
295+
"mlir::FunctionOpInterface"> {
296+
let summary = "Wrap encoding ops not in dispatches into a dispatch";
297+
let description = [{
298+
This pass is a clean up pass that runs after hoisting + fusion of encoding
299+
ops. It moves any hoisted + non-fused encodings into their own dispatch".
300+
}];
301+
let dependentDialects = [
302+
"IREE::Flow::FlowDialect",
303+
"IREE::Encoding::IREEEncodingDialect",
304+
];
305+
}
306+
293307
//===---------------------------------------------------------------------===//
294308
// Dispatch region to workgroups passes
295309
//===---------------------------------------------------------------------===//
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
8+
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
9+
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
10+
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
11+
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
12+
#include "iree/compiler/DispatchCreation/Passes.h"
13+
14+
namespace mlir::iree_compiler::DispatchCreation {
15+
16+
#define GEN_PASS_DEF_WRAPENCODINGOPINDISPATCHREGIONPASS
17+
#include "iree/compiler/DispatchCreation/Passes.h.inc"
18+
19+
namespace {
20+
21+
struct WrapEncodingOpInDispatchRegionPass
22+
: public impl::WrapEncodingOpInDispatchRegionPassBase<
23+
WrapEncodingOpInDispatchRegionPass> {
24+
25+
void runOnOperation() override;
26+
};
27+
28+
} // namespace
29+
30+
void WrapEncodingOpInDispatchRegionPass::runOnOperation() {
31+
MLIRContext *context = &getContext();
32+
mlir::FunctionOpInterface funcOp = getOperation();
33+
34+
SmallVector<IREE::Encoding::SetEncodingOp> encodingOps;
35+
funcOp->walk([&](IREE::Encoding::SetEncodingOp encodingOp) {
36+
if (IREE::Flow::isNonNullAndOutsideDispatch(encodingOp)) {
37+
encodingOps.push_back(encodingOp);
38+
}
39+
});
40+
41+
IRRewriter rewriter(context);
42+
for (auto encodingOp : encodingOps) {
43+
if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
44+
funcOp.emitOpError("failed to move encoding op into dispatch region");
45+
return signalPassFailure();
46+
}
47+
}
48+
}
49+
50+
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ iree_lit_test_suite(
4848
"split_reduction.mlir",
4949
"tensor_pad_to_tensor_insert_slice.mlir",
5050
"transpose_generic_ops.mlir",
51+
"wrap_encoding_op_in_dispatch_region.mlir",
5152
],
5253
include = ["*.mlir"],
5354
),

compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ iree_lit_test_suite(
4646
"split_reduction.mlir"
4747
"tensor_pad_to_tensor_insert_slice.mlir"
4848
"transpose_generic_ops.mlir"
49+
"wrap_encoding_op_in_dispatch_region.mlir"
4950
TOOLS
5051
FileCheck
5152
iree-opt

compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,8 @@ module {
7373
// CHECK: %[[REDUCTION:.+]] = linalg.generic
7474
// CHECK: flow.return %[[REDUCTION]] :
7575
// CHECK: }
76-
// CHECK: %[[DISPATCH_SE:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32, #[[$ENCODING]]>)
77-
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH]]
78-
// CHECK: flow.return %[[SET_ENCODING]] :
79-
// CHECK: }
80-
// CHECK: util.return %[[DISPATCH_SE]] : tensor<2x11008x128xf32, #[[$ENCODING]]>
76+
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH]]
77+
// CHECK: util.return %[[SET_ENCODING]] : tensor<2x11008x128xf32, #[[$ENCODING]]>
8178

8279
// -----
8380

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-wrap-encoding-op-in-dispatch-region))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
2+
3+
util.func @wrap_encoding_op(%arg0 : tensor<?x?xf32>)
4+
-> tensor<?x?xf32, #iree_encoding.testing_encoding<>> {
5+
%0 = iree_encoding.set_encoding %arg0 :
6+
tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.testing_encoding<>>
7+
util.return %0 : tensor<?x?xf32, #iree_encoding.testing_encoding<>>
8+
}
9+
// CHECK-LABEL: func public @wrap_encoding_op
10+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
11+
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
12+
// CHECK: %[[ENCODING:.+]] = iree_encoding.set_encoding %[[ARG0]]
13+
// CHECK: flow.return %[[ENCODING]]
14+
// CHECK: return %[[DISPATCH]]
15+
16+
// -----
17+
18+
util.func @do_not_wrap_encoding_op(%arg0 : tensor<?x?xf32>)
19+
-> tensor<?x?xf32, #iree_encoding.testing_encoding<>> {
20+
%c0 = arith.constant 0 : index
21+
%c1 = arith.constant 1 : index
22+
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
23+
%d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
24+
%0 = flow.dispatch.region -> (tensor<?x?xf32, #iree_encoding.testing_encoding<>>{%d0, %d1}) {
25+
%1 = iree_encoding.set_encoding %arg0 :
26+
tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.testing_encoding<>>
27+
flow.return %1 : tensor<?x?xf32, #iree_encoding.testing_encoding<>>
28+
}
29+
util.return %0 : tensor<?x?xf32, #iree_encoding.testing_encoding<>>
30+
}
31+
// CHECK-LABEL: func public @do_not_wrap_encoding_op
32+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
33+
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
34+
// CHECK: %[[ENCODING:.+]] = iree_encoding.set_encoding %[[ARG0]]
35+
// CHECK: flow.return %[[ENCODING]]
36+
// CHECK: return %[[DISPATCH]]

0 commit comments

Comments
 (0)