88#include " iree/compiler/Codegen/Common/UserConfig.h"
99#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1010#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
11+ #include " llvm/ADT/StringRef.h"
12+ #include " llvm/Support/LogicalResult.h"
1113#include " mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
14+ #include " mlir/IR/BuiltinOps.h"
15+ #include " mlir/IR/MLIRContext.h"
1216
1317#define DEBUG_TYPE " iree-codegen-materialize-user-configs"
1418#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
@@ -61,15 +65,64 @@ runTransformConfigurationStrategy(Operation *payloadRoot,
6165 return StrategyRunResult::Success;
6266}
6367
68+ struct TransformLibraryWithEntrypoint {
69+ ModuleOp transformLibrary;
70+ std::string entrypointName;
71+ };
72+
73+ static FailureOr<TransformLibraryWithEntrypoint>
74+ getTransformLibraryFromPath (ModuleOp compiledModule, StringRef path) {
75+ SmallVector<StringRef, 2 > parts;
76+ llvm::SplitString (path, parts, " @" );
77+ if (parts.empty ()) {
78+ return failure ();
79+ }
80+ if (parts.size () > 2 ) {
81+ return compiledModule.emitError ()
82+ << " Invalid transform library path and sequence name " << path;
83+ }
84+ StringRef libraryFileName = parts[0 ];
85+ StringRef entrySequenceName = kKernelConfigSpecName ;
86+ if (parts.size () == 2 ) {
87+ entrySequenceName = parts[1 ];
88+ }
89+
90+ // Validate both the file name and the spec name.
91+ if (libraryFileName.empty ()) {
92+ return compiledModule.emitError () << " Cannot specify an empty library path" ;
93+ }
94+ if (entrySequenceName.empty ()) {
95+ return compiledModule.emitError ()
96+ << " Cannot specify an empty sequence name" ;
97+ }
98+
99+ MLIRContext *ctx = compiledModule->getContext ();
100+ auto dialect = ctx->getOrLoadDialect <IREE::Codegen::IREECodegenDialect>();
101+ auto maybeTransformLibrary =
102+ dialect->getOrLoadTransformLibraryModule (libraryFileName.str ());
103+ if (failed (maybeTransformLibrary)) {
104+ return compiledModule.emitError ()
105+ << " Failed to load transform library module: " << libraryFileName;
106+ }
107+ LDBG (" --found transform library " << libraryFileName << " @"
108+ << entrySequenceName);
109+ return TransformLibraryWithEntrypoint{*maybeTransformLibrary,
110+ entrySequenceName.str ()};
111+ }
112+
64113struct MaterializeUserConfigsPass final
65114 : impl::MaterializeUserConfigsPassBase<MaterializeUserConfigsPass> {
66115 void getDependentDialects (DialectRegistry ®istry) const override {
67116 registerTransformDialectTranslationDependentDialects (registry);
68117 }
69118
70119 void runOnOperation () override {
71- auto moduleOp = getOperation ();
72- MLIRContext *context = &getContext ();
120+ ModuleOp moduleOp = getOperation ();
121+
122+ FailureOr<TransformLibraryWithEntrypoint> userTransformLibrary =
123+ getTransformLibraryFromPath (moduleOp,
124+ clCodegenTransformDialectLibraryFileName);
125+
73126 for (auto funcOp : moduleOp.getOps <FunctionOpInterface>()) {
74127
75128 // Parse the file path and kernel config strategy from flags. There are
@@ -84,54 +137,11 @@ struct MaterializeUserConfigsPass final
84137 // "translation_info" =
85138 // #iree_codegen.translation_info<pipeline = None>
86139 // ```
87- SmallVector<StringRef, 2 > parts;
88- llvm::SplitString (
89- llvm::StringRef (clCodegenTransformDialectLibraryFileName), parts,
90- " @" );
91- if (parts.size () > 2 ) {
92- funcOp.emitError ()
93- << " Invalid transform library path and sequence name "
94- << clCodegenTransformDialectLibraryFileName;
95- return signalPassFailure ();
96- }
97- bool hasTransformLibrary = !parts.empty ();
98-
99- std::string libraryFileName;
100- if (hasTransformLibrary) {
101- if (parts[0 ].empty ()) {
102- funcOp.emitError () << " Cannot specify an empty library path" ;
103- return signalPassFailure ();
104- }
105- libraryFileName = parts[0 ];
106- }
107-
108- StringRef entrySequenceName = kKernelConfigSpecName ;
109- // Check if the user specified a custom entry point name.
110- if (parts.size () == 2 ) {
111- if (parts[1 ].empty ()) {
112- funcOp.emitError () << " Cannot specify an empty sequence name" ;
113- return signalPassFailure ();
114- }
115- entrySequenceName = parts[1 ];
116- }
117-
118140 LDBG (" MaterializeUserConfigsPass on function: " << funcOp);
119- std::optional<ModuleOp> transformLibrary = std::nullopt ;
120- if (hasTransformLibrary) {
121- auto dialect =
122- context->getOrLoadDialect <IREE::Codegen::IREECodegenDialect>();
123- auto maybeTransformLibrary =
124- dialect->getOrLoadTransformLibraryModule (libraryFileName);
125- if (failed (maybeTransformLibrary)) {
126- funcOp.emitError ()
127- << " failed to load transform library module: " << libraryFileName;
128- return signalPassFailure ();
129- }
130- transformLibrary = *maybeTransformLibrary;
131- LDBG (" --found transform library @" << libraryFileName);
132-
141+ if (succeeded (userTransformLibrary)) {
142+ StringRef entrySequenceName = userTransformLibrary->entrypointName ;
133143 auto runResult = runTransformConfigurationStrategy (
134- funcOp, entrySequenceName, * transformLibrary);
144+ funcOp, entrySequenceName, userTransformLibrary-> transformLibrary );
135145 if (runResult == StrategyRunResult::NotFound) {
136146 funcOp.emitError () << " transform kernel config strategy `"
137147 << entrySequenceName << " not found" ;
@@ -186,9 +196,9 @@ struct MaterializeUserConfigsPass final
186196 // / If we have a symbol, verify the existence of the symbol within the
187197 // / transform library.
188198 StringRef entryPoint = strategyName->getLeafReference ();
189- if (!transformLibrary || !(*transformLibrary ) ||
190- !transform::detail::findTransformEntryPoint (funcOp, *transformLibrary,
191- entryPoint)) {
199+ if (failed (userTransformLibrary ) ||
200+ !transform::detail::findTransformEntryPoint (
201+ funcOp, userTransformLibrary-> transformLibrary , entryPoint)) {
192202 funcOp.emitOpError (" failed to find transform strategy symbol" );
193203 }
194204 }
0 commit comments