Skip to content

Commit 1a7b51d

Browse files
authored
[Codegen][Tuner] Add pass to link tuning specs (#19281)
This pass is meant for combining multiple tuning specs (e.g., a user-provided one and a default one). We expect the input module to have nested sub-modules with named sequences marked with the `iree_codegen.tuning_spec_entrypoint` unit attribute. The pass collects all such tuning specs and introduces a new named sequence that includes all the other tuning spec entry points. The order of inclusion is the same as the order in which these nested tuning specs appear in the IR. Issue: #19214
1 parent 3129fa9 commit 1a7b51d

File tree

9 files changed

+279
-5
lines changed

9 files changed

+279
-5
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ iree_compiler_cc_library(
122122
"IREEExpandStridedMetadata.cpp",
123123
"IREELoopInvariantCodeMotion.cpp",
124124
"InstrumentMemoryAccesses.cpp",
125+
"LinkTuningSpecsPass.cpp",
125126
"LowerExecutableUsingTransformDialect.cpp",
126127
"LowerUKernelsToCalls.cpp",
127128
"MaterializeEncodingIntoNop.cpp",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ iree_cc_library(
114114
"IREEExpandStridedMetadata.cpp"
115115
"IREELoopInvariantCodeMotion.cpp"
116116
"InstrumentMemoryAccesses.cpp"
117+
"LinkTuningSpecsPass.cpp"
117118
"LowerExecutableUsingTransformDialect.cpp"
118119
"LowerUKernelsToCalls.cpp"
119120
"MaterializeEncodingIntoNop.cpp"
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include <cassert>
8+
#include "iree/compiler/Codegen/Common/Passes.h"
9+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
10+
#include "llvm/ADT/STLExtras.h"
11+
#include "llvm/ADT/SmallVector.h"
12+
#include "llvm/ADT/SmallVectorExtras.h"
13+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
14+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
15+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
16+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
17+
#include "mlir/IR/Builders.h"
18+
#include "mlir/IR/BuiltinAttributes.h"
19+
#include "mlir/IR/BuiltinOps.h"
20+
#include "mlir/IR/Location.h"
21+
22+
#define DEBUG_TYPE "iree-codegen-link-tuning-specs"
23+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
24+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
25+
26+
namespace mlir::iree_compiler {
27+
28+
#define GEN_PASS_DEF_LINKTUNINGSPECSPASS
29+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
30+
31+
namespace {
32+
33+
using mlir::transform::NamedSequenceOp;
34+
35+
static SmallVector<ModuleOp>
36+
findNestedModulesWithNamedSequences(ModuleOp module) {
37+
Block *body = module.getBody();
38+
return llvm::to_vector(
39+
llvm::make_filter_range(body->getOps<ModuleOp>(), [](ModuleOp op) {
40+
return op.getSymName().has_value() &&
41+
op->hasAttr(
42+
transform::TransformDialect::kWithNamedSequenceAttrName);
43+
}));
44+
}
45+
46+
static SmallVector<NamedSequenceOp> findTuningSpecs(ModuleOp module) {
47+
Block *body = module.getBody();
48+
return llvm::to_vector(llvm::make_filter_range(
49+
body->getOps<NamedSequenceOp>(),
50+
[](NamedSequenceOp op) { return op->hasAttr(kTuningSpecAttrName); }));
51+
}
52+
53+
static LogicalResult validateTuningSpec(NamedSequenceOp op) {
54+
if (!op.getResultTypes().empty()) {
55+
op->emitWarning() << "Tuning spec expected to have no results";
56+
return failure();
57+
}
58+
59+
ArrayRef<Type> argTypes = op.getArgumentTypes();
60+
if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
61+
op->emitWarning() << "Tuning spec expected to have one argument of type "
62+
"'!transform.any_op'";
63+
return failure();
64+
}
65+
66+
if (!op.getArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName)) {
67+
op->emitWarning() << "Tuning spec expected to have one readonly argument";
68+
return failure();
69+
}
70+
71+
return success();
72+
}
73+
74+
static NamedSequenceOp
75+
emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
76+
OpBuilder builder(module->getContext());
77+
builder.setInsertionPointToEnd(module.getBody());
78+
79+
Location loc = builder.getFusedLoc(llvm::map_to_vector(
80+
specsToLink, [](NamedSequenceOp op) { return op->getLoc(); }));
81+
FunctionType specType = builder.getFunctionType(
82+
TypeRange{builder.getType<transform::AnyOpType>()}, TypeRange{});
83+
auto newSpec = builder.create<NamedSequenceOp>(
84+
loc, kKernelConfigSpecName, TypeAttr::get(specType),
85+
/*sym_visibility=*/StringAttr{},
86+
/*arg_attrs=*/ArrayAttr{},
87+
/*res_attrs*/ ArrayAttr{});
88+
newSpec.setArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName,
89+
builder.getUnitAttr());
90+
newSpec->setAttr(kTuningSpecAttrName, builder.getUnitAttr());
91+
92+
Region &region = newSpec.getRegion();
93+
Block *body = builder.createBlock(&region, region.begin(),
94+
newSpec.getArgumentTypes(), loc);
95+
builder.setInsertionPointToStart(body);
96+
97+
// Emit one `transform.include` op per child tuning spec. In the future,
98+
// we may want to switch to a custom transform op for this to perform
99+
// 'short-circuiting' and apply at most one tuning spec.
100+
Value operand = body->getArgument(0);
101+
for (NamedSequenceOp spec : specsToLink) {
102+
ModuleOp parentModule = spec->getParentOfType<ModuleOp>();
103+
assert(parentModule);
104+
StringAttr parentSymbol = parentModule.getSymNameAttr();
105+
assert(parentSymbol);
106+
auto symbol = SymbolRefAttr::get(
107+
parentSymbol, FlatSymbolRefAttr::get(spec.getSymNameAttr()));
108+
109+
// Suppress silenceable errors so that failures to match in child tuning
110+
// specs can be ignored.
111+
builder.create<transform::IncludeOp>(
112+
loc, TypeRange{}, symbol, transform::FailurePropagationMode::Suppress,
113+
operand);
114+
}
115+
116+
builder.create<transform::YieldOp>(loc);
117+
return newSpec;
118+
}
119+
120+
struct LinkTuningSpecsPass final
121+
: impl::LinkTuningSpecsPassBase<LinkTuningSpecsPass> {
122+
void getDependentDialects(DialectRegistry &registry) const override {
123+
registerTransformDialectTranslationDependentDialects(registry);
124+
}
125+
126+
void runOnOperation() override {
127+
ModuleOp module = getOperation();
128+
SmallVector<NamedSequenceOp> tuningSpecs;
129+
130+
for (ModuleOp nested : findNestedModulesWithNamedSequences(module)) {
131+
llvm::append_range(tuningSpecs, findTuningSpecs(nested));
132+
}
133+
134+
for (NamedSequenceOp spec : tuningSpecs) {
135+
LDBG("Found tuning spec: " << spec.getSymName());
136+
if (failed(validateTuningSpec(spec))) {
137+
return signalPassFailure();
138+
}
139+
}
140+
141+
if (tuningSpecs.empty()) {
142+
LDBG("No tuning specs found, exiting without linking");
143+
return;
144+
}
145+
146+
emitLinkedTuningSpec(module, tuningSpecs);
147+
}
148+
};
149+
150+
} // namespace
151+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ llvm::cl::opt<std::string> clCodegenTransformDialectLibraryFileName(
3131

3232
namespace {
3333

34-
constexpr StringLiteral kTranslationInfoAttrName = "translation_info";
35-
constexpr StringLiteral kDefaultTransformSequenceName = "__kernel_config";
34+
constexpr StringLiteral kTranslationInfoAttrName =
35+
IREE::Codegen::TranslationInfoAttr::name;
3636

3737
enum StrategyRunResult {
3838
Success = 0,
@@ -105,7 +105,7 @@ struct MaterializeUserConfigsPass final
105105
libraryFileName = parts[0];
106106
}
107107

108-
StringRef entrySequenceName = kDefaultTransformSequenceName;
108+
StringRef entrySequenceName = kKernelConfigSpecName;
109109
// Check if the user specified a custom entry point name.
110110
if (parts.size() == 2) {
111111
if (parts[1].empty()) {

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,21 @@ def InstrumentMemoryAccessesPass :
406406
let summary = "Instruments memory reads and writes for address tracking when dispatch instrumentation is enabled.";
407407
}
408408

409+
def LinkTuningSpecsPass : Pass<"iree-codegen-link-tuning-specs", "ModuleOp"> {
410+
let summary =
411+
"Link nested transform dialect tuning specs named sequences into a single entry point";
412+
let description = [{
413+
Given a module with multiple nested tuning specs, introduce a new named sequence
414+
that includes all the other tuning spec entry points. The order of inclusion is the same
415+
as the order in which these nested tuning specs appear in the IR.
416+
417+
A tuning spec entry point is a `transform.named_sequence` op annotated with the
418+
`iree_codegen.tuning_spec_entrypoint` unit attribute. We require it to perform in-place op
419+
modification and not consume the handle.
420+
}];
421+
let dependentDialects = ["transform::TransformDialect"];
422+
}
423+
409424
def LowerExecutableUsingTransformDialectPass :
410425
Pass<"iree-codegen-lower-executable-using-transform-dialect", "ModuleOp"> {
411426
let summary = "Lower executables using the transform dialect recipe provided in the module.";

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ iree_lit_test_suite(
5252
"iree_comprehensive_bufferize.mlir",
5353
"iree_expand_strided_metadata.mlir",
5454
"iree_loop_invariant_code_motion.mlir",
55+
"link_tuning_specs.mlir",
5556
"lower_ukernel_to_calls.mlir",
5657
"materialize_encoding_into_nop.mlir",
5758
"materialize_user_configs.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ iree_lit_test_suite(
4848
"iree_comprehensive_bufferize.mlir"
4949
"iree_expand_strided_metadata.mlir"
5050
"iree_loop_invariant_code_motion.mlir"
51+
"link_tuning_specs.mlir"
5152
"lower_ukernel_to_calls.mlir"
5253
"materialize_encoding_into_nop.mlir"
5354
"materialize_user_configs.mlir"
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// RUN: iree-opt %s --no-implicit-module --iree-codegen-link-tuning-specs --split-input-file \
2+
// RUN: | FileCheck %s
3+
4+
// CHECK-LABEL: module @td_module_0
5+
//
6+
// CHECK: transform.named_sequence @outer_spec
7+
//
8+
// CHECK: transform.named_sequence @__kernel_config
9+
// CHECK-SAME: (%arg0: !transform.any_op {transform.readonly})
10+
// CHECK-SAME: attributes {iree_codegen.tuning_spec_entrypoint}
11+
// CHECK: transform.include @foo_module::@foo failures(suppress)
12+
// CHECK-NEXT: transform.include @bar_module::@bar failures(suppress)
13+
// CHECK-NEXT: transform.include @baz_module::@baz failures(suppress)
14+
// CHECK-NEXT: transform.yield
15+
16+
module @td_module_0 attributes { transform.with_named_sequence } {
17+
module @foo_module attributes { transform.with_named_sequence } {
18+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> ()
19+
attributes { iree_codegen.tuning_spec_entrypoint } {
20+
transform.print {name = "Foo", skip_regions}
21+
transform.yield
22+
}
23+
}
24+
25+
module @bar_module attributes { transform.with_named_sequence } {
26+
transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> ()
27+
attributes { iree_codegen.tuning_spec_entrypoint } {
28+
transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
29+
transform.print {name = "Bar", skip_regions}
30+
transform.yield
31+
}
32+
}
33+
34+
module @baz_module attributes { transform.with_named_sequence } {
35+
transform.named_sequence @baz(%arg0: !transform.any_op {transform.readonly}) -> ()
36+
attributes { iree_codegen.tuning_spec_entrypoint } {
37+
transform.print {name = "Baz", skip_regions}
38+
transform.yield
39+
}
40+
}
41+
42+
transform.named_sequence @outer_spec(%module: !transform.any_op {transform.readonly}) -> ()
43+
attributes { iree_codegen.tuning_spec_entrypoint } {
44+
transform.yield
45+
}
46+
}
47+
48+
49+
// -----
50+
51+
// Here, `foo` shouldn't be included because it's not marked with `tuning_spec_entrypoint`.
52+
53+
// CHECK-LABEL: module @td_module_1
54+
// CHECK: @foo_module
55+
// CHECK: @__kernel_config
56+
// CHECK-NOT: transform.include @foo_module::@foo failures(suppress) (%arg0) : (!transform.any_op) -> ()
57+
// CHECK: transform.include @foo_module::@bar failures(suppress) (%arg0) : (!transform.any_op) -> ()
58+
// CHECK-NEXT: transform.yield
59+
60+
module @td_module_1 attributes { transform.with_named_sequence } {
61+
module @foo_module attributes { transform.with_named_sequence } {
62+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
63+
transform.yield
64+
}
65+
transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> ()
66+
attributes { iree_codegen.tuning_spec_entrypoint } {
67+
transform.yield
68+
}
69+
func.func @baz(%arg0: i32) -> () {
70+
return
71+
}
72+
}
73+
}
74+
75+
76+
// -----
77+
78+
// Make sure we do not crash on modules with no tuning specs.
79+
80+
// CHECK-LABEL: module @td_module_2
81+
// CHECK-NOT: @__kernel_config
82+
module @td_module_2 attributes { transform.with_named_sequence } {}
83+
84+
// -----
85+
86+
// Make sure we do not crash on unnamed nested modules.
87+
88+
// CHECK-LABEL: module @td_module_3
89+
// CHECK: transform.named_sequence @foo
90+
// CHECK-NOT: @__kernel_config
91+
92+
module @td_module_3 attributes { transform.with_named_sequence } {
93+
module attributes { transform.with_named_sequence } {
94+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> ()
95+
attributes { iree_codegen.tuning_spec_entrypoint } {
96+
transform.yield
97+
}
98+
}
99+
}

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ using ScalableTileFlagsListTypeRef = ArrayRef<SmallVector<bool>>;
3434
// clang-format on
3535

3636
namespace mlir::iree_compiler {
37+
//===----------------------------------------------------------------------===//
38+
// Constant names.
39+
//===----------------------------------------------------------------------===//
40+
constexpr StringLiteral kConfigAttrName = "lowering_config";
41+
constexpr StringLiteral kTuningSpecAttrName =
42+
"iree_codegen.tuning_spec_entrypoint";
43+
constexpr StringLiteral kKernelConfigSpecName = "__kernel_config";
3744

3845
//===----------------------------------------------------------------------===//
3946
// Helpers for getting/setting iree_codegen.translation_info attribute on the
@@ -66,8 +73,6 @@ void eraseTranslationInfo(mlir::FunctionOpInterface funcOp);
6673
// operations.
6774
//===----------------------------------------------------------------------===//
6875

69-
static const char kConfigAttrName[] = "lowering_config";
70-
7176
/// Returns the lowering configuration set for an operation. Returns `nullptr`
7277
/// if no value is set. It expects that the attribute is stored using the
7378
/// identifier `lowering_config`.

0 commit comments

Comments
 (0)