Skip to content

Commit 53e15a4

Browse files
authored
[Codegen] Add pass for interpreting transform specs on lowering configs (iree-org#20408)
The existing mechanism for externally authored codegen strategies via transform dialect overrides almost the entire codegen pipeline with the transform dialect strategy. This does not scale as it typically either a) Is no better than a standard pass pipeline, in which case there is no reason to use transform dialect. or b) Is specialized for a single case which railroads future graph level/fusion efforts. A more scalable middle ground is to thread through transform strategies that apply specifically to certain operations and let the surrounding pipeline handle the rest. This patch adds a pass that walks a module and looks up + applies a transform strategy on any annotated op. One concern with transform dialect passes has been dialect registration. The only dialect this pass *requires* is iree_codegen to load the strategy library. Otherwise the current assumption is that this will be used within a TargetBackend's pass pipeline, meaning each target backend adds the dialect that needs to be registered as a part of the surrounding pipeline. I'm not 100% how this will pan out with reproducers but there aren't any good alternatives (registering everything is a no-go).
1 parent 25d9d60 commit 53e15a4

19 files changed

+516
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ iree_compiler_cc_library(
128128
"LinkTuningSpecsPass.cpp",
129129
"LowerExecutableUsingTransformDialect.cpp",
130130
"LowerUKernelsToCalls.cpp",
131+
"LoweringConfigInterpreter.cpp",
131132
"MaterializeEncoding.cpp",
132133
"MaterializeEncodingIntoNop.cpp",
133134
"MaterializeEncodingIntoPadding.cpp",
@@ -252,6 +253,7 @@ iree_compiler_cc_library(
252253
"@llvm-project//mlir:TensorUtils",
253254
"@llvm-project//mlir:TilingInterface",
254255
"@llvm-project//mlir:TransformDialect",
256+
"@llvm-project//mlir:TransformDialectTransforms",
255257
"@llvm-project//mlir:TransformUtils",
256258
"@llvm-project//mlir:Transforms",
257259
"@llvm-project//mlir:ValueBoundsOpInterface",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ iree_cc_library(
121121
"LinkTuningSpecsPass.cpp"
122122
"LowerExecutableUsingTransformDialect.cpp"
123123
"LowerUKernelsToCalls.cpp"
124+
"LoweringConfigInterpreter.cpp"
124125
"MaterializeEncoding.cpp"
125126
"MaterializeEncodingIntoNop.cpp"
126127
"MaterializeEncodingIntoPadding.cpp"
@@ -209,6 +210,7 @@ iree_cc_library(
209210
MLIRTensorUtils
210211
MLIRTilingInterface
211212
MLIRTransformDialect
213+
MLIRTransformDialectTransforms
212214
MLIRTransformUtils
213215
MLIRTransforms
214216
MLIRValueBoundsOpInterface
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include <utility>
8+
#include "iree/compiler/Codegen/Common/Passes.h"
9+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
10+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
11+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
12+
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
13+
#include "mlir/IR/AsmState.h"
14+
#include "mlir/IR/SymbolTable.h"
15+
#include "mlir/IR/Visitors.h"
16+
#include "mlir/Parser/Parser.h"
17+
#include "mlir/Pass/Pass.h"
18+
19+
namespace mlir::iree_compiler {
20+
21+
#define GEN_PASS_DEF_LOWERINGCONFIGINTERPRETERPASS
22+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
23+
24+
constexpr StringLiteral kCodegenExternalSymbolsAttrName =
25+
"iree_codegen_external_symbols";
26+
27+
/// Look up the tuning spec in the given module or any of its parents.
28+
static LogicalResult
29+
getSerializedExternalSymbols(Operation *op,
30+
OwningOpRef<ModuleOp> &symbolsModule) {
31+
auto serializedExternalModule =
32+
op->getAttrOfType<IREE::Util::SerializableAttrInterface>(
33+
kCodegenExternalSymbolsAttrName);
34+
35+
if (!serializedExternalModule) {
36+
return success();
37+
}
38+
39+
SmallVector<char, 0> bytecode;
40+
if (failed(serializedExternalModule.serializeToVector(
41+
op->getLoc(), llvm::endianness::native, bytecode))) {
42+
return op->emitError() << "Failed to read attribute "
43+
<< kCodegenExternalSymbolsAttrName;
44+
}
45+
46+
ParserConfig config(serializedExternalModule.getContext());
47+
symbolsModule = parseSourceString<ModuleOp>(
48+
StringRef(bytecode.data(), bytecode.size()), config);
49+
if (!symbolsModule) {
50+
return op->emitError() << "Failed to parse module in "
51+
<< kCodegenExternalSymbolsAttrName;
52+
}
53+
return success();
54+
}
55+
56+
namespace {
57+
class LoweringConfigInterpreterPass final
58+
: public impl::LoweringConfigInterpreterPassBase<
59+
LoweringConfigInterpreterPass> {
60+
public:
61+
using Base::Base;
62+
void runOnOperation() override {
63+
Operation *rootOp = getOperation();
64+
65+
// Supports both inline strategy IR and externally cached using the
66+
// transform library module mechanism. Inline strategies take precedence
67+
// over external ones in case a symbol matches in both.
68+
auto *symbolTableOp = SymbolTable::getNearestSymbolTable(rootOp);
69+
OwningOpRef<ModuleOp> parsedLibrary;
70+
if (failed(getSerializedExternalSymbols(rootOp, parsedLibrary))) {
71+
return signalPassFailure();
72+
}
73+
74+
// Collect the list of operation + strategy pairs.
75+
SmallVector<std::pair<Operation *, transform::NamedSequenceOp>>
76+
targetStrategyPairs;
77+
WalkResult res = rootOp->walk([&](Operation *op) {
78+
IREE::Codegen::LoweringConfigAttrInterface loweringConfig =
79+
getLoweringConfig(op);
80+
if (!loweringConfig) {
81+
return WalkResult::advance();
82+
}
83+
84+
std::optional<StringRef> maybeSymName =
85+
loweringConfig.getLoweringStrategy();
86+
if (!maybeSymName) {
87+
return WalkResult::advance();
88+
}
89+
90+
auto strategy = dyn_cast_or_null<transform::NamedSequenceOp>(
91+
SymbolTable::lookupSymbolIn(symbolTableOp, *maybeSymName));
92+
if (!strategy && parsedLibrary) {
93+
strategy = dyn_cast_or_null<transform::NamedSequenceOp>(
94+
SymbolTable::lookupSymbolIn(parsedLibrary->getOperation(),
95+
*maybeSymName));
96+
}
97+
98+
// Fail if the strategy cannot be found for some reason. We could pass
99+
// through silently here as it's technically not a hard failure, however
100+
// this creates performance chasms on a predominantly user driven path.
101+
if (!strategy) {
102+
op->emitError("Could not find required strategy ") << *maybeSymName;
103+
return WalkResult::interrupt();
104+
}
105+
106+
targetStrategyPairs.push_back({op, strategy});
107+
return WalkResult::advance();
108+
});
109+
110+
if (res.wasInterrupted()) {
111+
return signalPassFailure();
112+
}
113+
114+
// Apply the lowering strategies in no particular order. It is up to the
115+
// underlying strategies to make sure they don't step on each others toes
116+
// if multiple are present.
117+
transform::TransformOptions options;
118+
for (auto [target, strategy] : targetStrategyPairs) {
119+
if (failed(transform::applyTransformNamedSequence(
120+
target, strategy, /*transformModule=*/nullptr, options))) {
121+
return signalPassFailure();
122+
}
123+
}
124+
125+
// Drop the serialized external symbols if present as we no longer need
126+
// them.
127+
if (rootOp->hasAttr(kCodegenExternalSymbolsAttrName)) {
128+
rootOp->removeAttr(kCodegenExternalSymbolsAttrName);
129+
}
130+
}
131+
};
132+
} // namespace
133+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,16 @@ def LowerExecutableUsingTransformDialectPass :
451451
let summary = "Lower executables using the transform dialect recipe provided in the module.";
452452
}
453453

454+
def LoweringConfigInterpreterPass :
455+
Pass<"iree-codegen-lowering-config-interpreter"> {
456+
let summary = "Pass to apply lowering config annotated strategies.";
457+
let description = [{
458+
This pass runs the transform dialect interpreter and applies the named
459+
sequence transformation specified by lowering configs annotated on
460+
operations.
461+
}];
462+
}
463+
454464
def LowerUKernelOpsToCallsPass :
455465
Pass<"iree-codegen-lower-ukernel-ops-to-calls", "ModuleOp"> {
456466
let summary = "Lower micro-kernel wrapper ops into function calls";

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ iree_lit_test_suite(
6363
"link_tuning_specs.mlir",
6464
"llvmcpu_materialize_encoding.mlir",
6565
"lower_ukernel_to_calls.mlir",
66+
"lowering_config_interpreter.mlir",
6667
"materialize_encoding_into_nop.mlir",
6768
"materialize_encoding_into_padding.mlir",
6869
"materialize_tuning_specs.mlir",
@@ -110,6 +111,7 @@ iree_lit_test_suite(
110111
exclude = [
111112
"batch_matmul_match_spec.mlir",
112113
"convolution_match_spec.mlir",
114+
"external_strategy_spec.mlir",
113115
"reductions_codegen_spec.mlir",
114116
"reductions_match_spec.mlir",
115117
"tuning_spec.mlir",
@@ -122,6 +124,7 @@ iree_lit_test_suite(
122124
data = [
123125
"batch_matmul_match_spec.mlir",
124126
"convolution_match_spec.mlir",
127+
"external_strategy_spec.mlir",
125128
"reductions_codegen_spec.mlir",
126129
"reductions_match_spec.mlir",
127130
"tuning_spec.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ iree_lit_test_suite(
5959
"link_tuning_specs.mlir"
6060
"llvmcpu_materialize_encoding.mlir"
6161
"lower_ukernel_to_calls.mlir"
62+
"lowering_config_interpreter.mlir"
6263
"materialize_encoding_into_nop.mlir"
6364
"materialize_encoding_into_padding.mlir"
6465
"materialize_tuning_specs.mlir"
@@ -107,6 +108,7 @@ iree_lit_test_suite(
107108
DATA
108109
batch_matmul_match_spec.mlir
109110
convolution_match_spec.mlir
111+
external_strategy_spec.mlir
110112
reductions_codegen_spec.mlir
111113
reductions_match_spec.mlir
112114
tuning_spec.mlir
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: iree-opt %s
2+
3+
module @user_spec attributes { transform.with_named_sequence } {
4+
transform.named_sequence @lowering_strategy(%op: !transform.any_op {transform.readonly}) {
5+
transform.print {name = "I am external", skip_regions}
6+
transform.yield
7+
}
8+
transform.named_sequence @import_lowering_strategy(%op: !transform.any_op {transform.readonly}) -> !transform.any_op
9+
attributes { iree_codegen.tuning_spec_entrypoint } {
10+
%syms = transform.util.create_serialized_module {
11+
^bb0(%m: !transform.any_op):
12+
transform.util.import_symbol @lowering_strategy into %m if undefined : (!transform.any_op) -> !transform.any_op
13+
transform.annotate %m "transform.with_named_sequence" : !transform.any_op
14+
} -> !transform.any_param
15+
transform.annotate %op "iree_codegen_external_symbols" = %syms : !transform.any_op, !transform.any_param
16+
transform.yield %op : !transform.any_op
17+
}
18+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-codegen-materialize-tuning-specs, iree-codegen-materialize-user-configs, func.func(iree-codegen-lowering-config-interpreter))" \
2+
// RUN: --iree-codegen-tuning-spec-path=%p/external_strategy_spec.mlir --split-input-file | FileCheck %s
3+
4+
#config = #iree_gpu.lowering_config<{lowering_strategy = "print_me"}>
5+
module attributes { transform.with_named_sequence } {
6+
func.func @single_config(%arg0: i32) -> i32 {
7+
%add = arith.addi %arg0, %arg0 {lowering_config = #config} : i32
8+
return %add : i32
9+
}
10+
11+
transform.named_sequence @print_me(%op: !transform.any_op {transform.readonly}) {
12+
transform.print %op : !transform.any_op
13+
transform.yield
14+
}
15+
}
16+
17+
// CHECK: IR printer:
18+
// CHECK-NEXT: arith.addi
19+
20+
// -----
21+
22+
#config1 = #iree_gpu.lowering_config<{lowering_strategy = "print_one"}>
23+
#config2 = #iree_gpu.lowering_config<{lowering_strategy = "print_two"}>
24+
#config3 = #iree_gpu.lowering_config<{lowering_strategy = "print_three"}>
25+
module attributes { transform.with_named_sequence } {
26+
func.func @multi_config(%arg0: i32) -> i32 {
27+
%add1 = arith.addi %arg0, %arg0 {lowering_config = #config1} : i32
28+
%add2 = arith.addi %add1, %add1 {lowering_config = #config2} : i32
29+
%add3 = arith.addi %add2, %add2 {lowering_config = #config3} : i32
30+
return %add3 : i32
31+
}
32+
33+
transform.named_sequence @print_one(%op: !transform.any_op {transform.readonly}) {
34+
transform.print %op {name = "one"} : !transform.any_op
35+
transform.yield
36+
}
37+
transform.named_sequence @print_two(%op: !transform.any_op {transform.readonly}) {
38+
transform.print %op {name = "two"} : !transform.any_op
39+
transform.yield
40+
}
41+
transform.named_sequence @print_three(%op: !transform.any_op {transform.readonly}) {
42+
transform.print %op {name = "three"} : !transform.any_op
43+
transform.yield
44+
}
45+
}
46+
47+
// CHECK: IR printer:
48+
// CHECK-DAG: one
49+
// CHECK-DAG: two
50+
// CHECK-DAG: three
51+
52+
// -----
53+
54+
#config = #iree_gpu.lowering_config<{lowering_strategy = "lowering_strategy"}>
55+
module {
56+
func.func @external_strategy(%arg0: i32) -> i32 {
57+
%add = arith.addi %arg0, %arg0 {lowering_config = #config} : i32
58+
return %add : i32
59+
}
60+
}
61+
62+
// See ./external_strategy_spec.mlir for the implementation of
63+
// "lowering_strategy" annotated for this test.
64+
//
65+
// CHECK: IR printer: I am external
66+
67+
// -----
68+
69+
#config = #iree_gpu.lowering_config<{lowering_strategy = "lowering_strategy"}>
70+
module attributes { transform.with_named_sequence } {
71+
func.func @override_external_strategy(%arg0: i32) -> i32 {
72+
%add = arith.addi %arg0, %arg0 {lowering_config = #config} : i32
73+
return %add : i32
74+
}
75+
76+
transform.named_sequence @lowering_strategy(%op: !transform.any_op {transform.readonly}) {
77+
transform.print {name = "I am internal"}
78+
transform.yield
79+
}
80+
}
81+
82+
// CHECK: IR printer: I am internal

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,19 @@ def IREECodegen_LoweringConfigAttrInterface :
112112
/*defaultImplementation=*/[{
113113
return ::llvm::SmallVector<OpFoldResult>();
114114
}]
115+
>,
116+
InterfaceMethod<
117+
/*desc=*/[{
118+
Gets the name of the custom lowering strategy to apply to the annotated
119+
operation.
120+
}],
121+
/*retTy=*/"::std::optional<::llvm::StringRef>",
122+
/*methodName=*/"getLoweringStrategy",
123+
/*args=*/(ins),
124+
/*methodBody=*/"",
125+
/*defaultImplementation=*/[{
126+
return std::nullopt;
127+
}]
115128
>
116129
];
117130
}

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,15 @@ bool LoweringConfigAttr::hasWorkgroupTilingLevel() const {
12961296
return !getWorkgroupTileSizes().empty();
12971297
}
12981298

1299+
constexpr StringLiteral kLoweringStrategyName = "lowering_strategy";
1300+
1301+
std::optional<StringRef> LoweringConfigAttr::getLoweringStrategy() const {
1302+
if (auto name = getAttributes().getAs<StringAttr>(kLoweringStrategyName)) {
1303+
return name.strref();
1304+
}
1305+
return std::nullopt;
1306+
}
1307+
12991308
//===----------------------------------------------------------------------===//
13001309
// DerivedThreadConfigAttr
13011310
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)