Skip to content

Commit cb467bf

Browse files
authored
[DT] Add a flag to not hoist encodings when the source is ConstExpr. (#20526)
The revision adds an option to not hoist encodings when the source is a global.load op or has ConstantLike trait. It is a workaround for huge memory allocation because the ops are still in dispatch region. I.e., it stops hoisting globals to initializer. Signed-off-by: hanhanW <[email protected]>
1 parent 2562998 commit cb467bf

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
1111
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
1212
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
13+
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
1314
#include "iree/compiler/DispatchCreation/Passes.h"
1415
#include "llvm/ADT/STLExtras.h"
1516
#include "llvm/Support/Debug.h"
@@ -188,9 +189,16 @@ void HoistEncodingOpsPass::runOnOperation() {
188189

189190
SmallVector<IREE::Encoding::SetEncodingOp> candidates;
190191
funcOp->walk([&](IREE::Encoding::SetEncodingOp setEncodingOp) {
191-
if (setEncodingOp->getParentOfType<IREE::Flow::DispatchRegionOp>()) {
192-
candidates.push_back(setEncodingOp);
192+
if (!setEncodingOp->getParentOfType<IREE::Flow::DispatchRegionOp>()) {
193+
return;
193194
}
195+
Operation *src = setEncodingOp.getSource().getDefiningOp();
196+
if (!hoistEncodingsForConstExpr && src &&
197+
(isa<IREE::Util::GlobalLoadOp>(src) ||
198+
src->hasTrait<OpTrait::ConstantLike>())) {
199+
return;
200+
}
201+
candidates.push_back(setEncodingOp);
194202
});
195203
IRRewriter rewriter(ctx);
196204
for (auto setEncodingOp : candidates) {

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ static llvm::cl::opt<bool> clEnableDataTiling(
8080
"path, --iree-opt-data-tiling=false must be set as wells"),
8181
llvm::cl::init(false));
8282

83+
static llvm::cl::opt<bool> clHoistEncodingsForConstExpr(
84+
"iree-dispatch-creation-hoist-encodings-for-constexpr",
85+
llvm::cl::desc("Enable the hoisting of encoding ops when the source is "
86+
"from globals. To use this path, "
87+
"--iree-opt-data-tiling=false must be set as wells"),
88+
llvm::cl::init(true));
89+
8390
namespace mlir::iree_compiler::DispatchCreation {
8491

8592
//===----------------------------------------------------------------------===//
@@ -252,9 +259,10 @@ addDispatchRegionCreationPasses(OpPassManager &passManager,
252259
// op, so hoist them out of their current dispatch regions. Also, bubble
253260
// SetEncodingOps through special operations like bit-extending ops and
254261
// broadcasting ops.
255-
.addPass(DispatchCreation::createHoistEncodingOpsPass)
256-
// After SetEncodingOps are hoisted, try to fuse them with their
257-
// producer dispatches to try to hide packing costs.
262+
.addPass([&]() {
263+
return DispatchCreation::createHoistEncodingOpsPass(
264+
HoistEncodingOpsPassOptions{clHoistEncodingsForConstExpr});
265+
})
258266
.addPass(
259267
DispatchCreation::createFuseEncodingOpsIntoDispatchRegionsPass);
260268
}

compiler/src/iree/compiler/DispatchCreation/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ def FuseEncodingOpsIntoDispatchRegionsPass :
286286
def HoistEncodingOpsPass :
287287
InterfacePass<"iree-dispatch-creation-hoist-encoding-ops", "mlir::FunctionOpInterface"> {
288288
let summary = "Hoists tensor encoding ops out of flow dispatch regions.";
289+
let options = [
290+
Option<"hoistEncodingsForConstExpr", "hoist-encodings-for-constexpr", "bool", /*default=*/"true",
291+
"Enable the hoisting when the source is a ConstExpr.">,
292+
];
289293
let dependentDialects = [
290294
"mlir::linalg::LinalgDialect",
291295
"IREE::Flow::FlowDialect",

compiler/src/iree/compiler/DispatchCreation/test/hoist_encoding_ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-hoist-encoding-ops))" --split-input-file %s | FileCheck %s
2+
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-hoist-encoding-ops{hoist-encodings-for-constexpr=false}))" --split-input-file %s | FileCheck %s --check-prefix=NO-CONST
23

34
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
45
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -247,3 +248,8 @@ util.func public @hoist_encoding_only() -> tensor<640x320xf32> {
247248
// CHECK: %[[CST:.+]] = arith.constant
248249
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[CST]]
249250
// CHECK: flow.dispatch.region
251+
252+
// NO-CONST-LABEL: util.func public @hoist_encoding_only(
253+
// NO-CONST: %[[CST:.+]] = arith.constant
254+
// NO-CONST: flow.dispatch.region
255+
// NO-CONST: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[CST]]

0 commit comments

Comments
 (0)