Skip to content

Commit d34bf98

Browse files
kuhargiacs-epic
authored andcommitted
[Codegen] Load transform library only once in MaterializeUserConfigs (iree-org#19313)
Hoist the library loading logic out of the loop that configures functions. This is in preparation for adding tuning spec loading from a new module attr. Issue: iree-org#19214 Signed-off-by: Giacomo Serafini <[email protected]>
1 parent cc61a73 commit d34bf98

File tree

1 file changed

+61
-51
lines changed

1 file changed

+61
-51
lines changed

compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp

Lines changed: 61 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
#include "iree/compiler/Codegen/Common/UserConfig.h"
99
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1010
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
11+
#include "llvm/ADT/StringRef.h"
12+
#include "llvm/Support/LogicalResult.h"
1113
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
14+
#include "mlir/IR/BuiltinOps.h"
15+
#include "mlir/IR/MLIRContext.h"
1216

1317
#define DEBUG_TYPE "iree-codegen-materialize-user-configs"
1418
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -61,15 +65,64 @@ runTransformConfigurationStrategy(Operation *payloadRoot,
6165
return StrategyRunResult::Success;
6266
}
6367

68+
struct TransformLibraryWithEntrypoint {
69+
ModuleOp transformLibrary;
70+
std::string entrypointName;
71+
};
72+
73+
static FailureOr<TransformLibraryWithEntrypoint>
74+
getTransformLibraryFromPath(ModuleOp compiledModule, StringRef path) {
75+
SmallVector<StringRef, 2> parts;
76+
llvm::SplitString(path, parts, "@");
77+
if (parts.empty()) {
78+
return failure();
79+
}
80+
if (parts.size() > 2) {
81+
return compiledModule.emitError()
82+
<< "Invalid transform library path and sequence name " << path;
83+
}
84+
StringRef libraryFileName = parts[0];
85+
StringRef entrySequenceName = kKernelConfigSpecName;
86+
if (parts.size() == 2) {
87+
entrySequenceName = parts[1];
88+
}
89+
90+
// Validate both the file name and the spec name.
91+
if (libraryFileName.empty()) {
92+
return compiledModule.emitError() << "Cannot specify an empty library path";
93+
}
94+
if (entrySequenceName.empty()) {
95+
return compiledModule.emitError()
96+
<< "Cannot specify an empty sequence name";
97+
}
98+
99+
MLIRContext *ctx = compiledModule->getContext();
100+
auto dialect = ctx->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
101+
auto maybeTransformLibrary =
102+
dialect->getOrLoadTransformLibraryModule(libraryFileName.str());
103+
if (failed(maybeTransformLibrary)) {
104+
return compiledModule.emitError()
105+
<< "Failed to load transform library module: " << libraryFileName;
106+
}
107+
LDBG("--found transform library " << libraryFileName << "@"
108+
<< entrySequenceName);
109+
return TransformLibraryWithEntrypoint{*maybeTransformLibrary,
110+
entrySequenceName.str()};
111+
}
112+
64113
struct MaterializeUserConfigsPass final
65114
: impl::MaterializeUserConfigsPassBase<MaterializeUserConfigsPass> {
66115
void getDependentDialects(DialectRegistry &registry) const override {
67116
registerTransformDialectTranslationDependentDialects(registry);
68117
}
69118

70119
void runOnOperation() override {
71-
auto moduleOp = getOperation();
72-
MLIRContext *context = &getContext();
120+
ModuleOp moduleOp = getOperation();
121+
122+
FailureOr<TransformLibraryWithEntrypoint> userTransformLibrary =
123+
getTransformLibraryFromPath(moduleOp,
124+
clCodegenTransformDialectLibraryFileName);
125+
73126
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
74127

75128
// Parse the file path and kernel config strategy from flags. There are
@@ -84,54 +137,11 @@ struct MaterializeUserConfigsPass final
84137
// "translation_info" =
85138
// #iree_codegen.translation_info<pipeline = None>
86139
// ```
87-
SmallVector<StringRef, 2> parts;
88-
llvm::SplitString(
89-
llvm::StringRef(clCodegenTransformDialectLibraryFileName), parts,
90-
"@");
91-
if (parts.size() > 2) {
92-
funcOp.emitError()
93-
<< "Invalid transform library path and sequence name "
94-
<< clCodegenTransformDialectLibraryFileName;
95-
return signalPassFailure();
96-
}
97-
bool hasTransformLibrary = !parts.empty();
98-
99-
std::string libraryFileName;
100-
if (hasTransformLibrary) {
101-
if (parts[0].empty()) {
102-
funcOp.emitError() << "Cannot specify an empty library path";
103-
return signalPassFailure();
104-
}
105-
libraryFileName = parts[0];
106-
}
107-
108-
StringRef entrySequenceName = kKernelConfigSpecName;
109-
// Check if the user specified a custom entry point name.
110-
if (parts.size() == 2) {
111-
if (parts[1].empty()) {
112-
funcOp.emitError() << "Cannot specify an empty sequence name";
113-
return signalPassFailure();
114-
}
115-
entrySequenceName = parts[1];
116-
}
117-
118140
LDBG("MaterializeUserConfigsPass on function: " << funcOp);
119-
std::optional<ModuleOp> transformLibrary = std::nullopt;
120-
if (hasTransformLibrary) {
121-
auto dialect =
122-
context->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
123-
auto maybeTransformLibrary =
124-
dialect->getOrLoadTransformLibraryModule(libraryFileName);
125-
if (failed(maybeTransformLibrary)) {
126-
funcOp.emitError()
127-
<< "failed to load transform library module: " << libraryFileName;
128-
return signalPassFailure();
129-
}
130-
transformLibrary = *maybeTransformLibrary;
131-
LDBG("--found transform library @" << libraryFileName);
132-
141+
if (succeeded(userTransformLibrary)) {
142+
StringRef entrySequenceName = userTransformLibrary->entrypointName;
133143
auto runResult = runTransformConfigurationStrategy(
134-
funcOp, entrySequenceName, *transformLibrary);
144+
funcOp, entrySequenceName, userTransformLibrary->transformLibrary);
135145
if (runResult == StrategyRunResult::NotFound) {
136146
funcOp.emitError() << "transform kernel config strategy `"
137147
<< entrySequenceName << " not found";
@@ -186,9 +196,9 @@ struct MaterializeUserConfigsPass final
186196
/// If we have a symbol, verify the existence of the symbol within the
187197
/// transform library.
188198
StringRef entryPoint = strategyName->getLeafReference();
189-
if (!transformLibrary || !(*transformLibrary) ||
190-
!transform::detail::findTransformEntryPoint(funcOp, *transformLibrary,
191-
entryPoint)) {
199+
if (failed(userTransformLibrary) ||
200+
!transform::detail::findTransformEntryPoint(
201+
funcOp, userTransformLibrary->transformLibrary, entryPoint)) {
192202
funcOp.emitOpError("failed to find transform strategy symbol");
193203
}
194204
}

0 commit comments

Comments
 (0)