@@ -38,6 +38,7 @@ limitations under the License.
3838#include " absl/status/status.h"
3939#include " absl/status/statusor.h"
4040#include " absl/strings/str_cat.h"
41+ #include " absl/strings/str_format.h"
4142#include " absl/synchronization/mutex.h"
4243#include " llvm/include/llvm/ADT/SmallVector.h"
4344#include " llvm/include/llvm/Support/CodeGen.h"
@@ -415,6 +416,40 @@ GetKernelCache() {
415416 return std::make_pair (&context_cache, &mutex);
416417}
417418
419+ absl::StatusOr<std::pair<std::string, std::string>> GetHostAndInitFuncNames (
420+ mlir::ModuleOp module_op) {
421+ // We look for two top level C-interface functions:
422+ // - "host" function with symbol name "_mlir_ciface_<foo>"
423+ // - "init" function with symbol name "_mlir_ciface_<foo>_init"
424+ constexpr std::string_view prefix = " _mlir_ciface_" ;
425+ std::vector<std::string> names;
426+ for (mlir::LLVM::LLVMFuncOp llvm_func :
427+ module_op.getOps <mlir::LLVM::LLVMFuncOp>()) {
428+ if (llvm_func.getName ().starts_with (prefix)) {
429+ names.push_back (llvm_func.getName ().str ());
430+ }
431+ }
432+ if (auto size = names.size (); size != 2 ) {
433+ return absl::InternalError (absl::StrFormat (
434+ " Expected to locate 2 symbols with %s prefix in the MLIR module, found "
435+ " %d instead." ,
436+ prefix, size));
437+ }
438+ // _mlir_ciface_<foo>_init now follows _mlir_ciface_<foo>
439+ std::sort (names.begin (), names.end ());
440+
441+ std::string host_func_name = names[0 ];
442+ std::string init_func_name = names[1 ];
443+
444+ if (init_func_name != absl::StrCat (host_func_name, " _init" )) {
445+ return absl::InternalError (absl::StrFormat (
446+ " Expected init function name to equal the concatenation of the host "
447+ " function name and the \" _init\" suffix, instead got "
448+ " init_func_name=%s, host_func_name=%s." ,
449+ init_func_name, host_func_name));
450+ }
451+ return std::make_pair (host_func_name, init_func_name);
452+ }
418453
419454absl::StatusOr<CompiledKernel> CompileAndInit (const char * module ) {
420455 mlir::MLIRContext context (mlir::MLIRContext::Threading::DISABLED);
@@ -430,9 +465,16 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
430465 return maybe_engine.status ();
431466 }
432467 mlir::ExecutionEngine* execution_engine = maybe_engine->get ();
433- auto main = execution_engine->lookupPacked (" _mlir_ciface_main" );
434- auto init = execution_engine->lookupPacked (" _mlir_ciface_main_init" );
435- if (!init || !main) {
468+
469+ auto host_and_init_func_names = GetHostAndInitFuncNames (*module_op);
470+ if (!host_and_init_func_names.ok ()) {
471+ return host_and_init_func_names.status ();
472+ }
473+ auto [host_name, init_name] = host_and_init_func_names.value ();
474+
475+ auto host = execution_engine->lookupPacked (host_name);
476+ auto init = execution_engine->lookupPacked (init_name);
477+ if (!init || !host) {
436478 return absl::InternalError (" Failed to retrieve kernel function" );
437479 }
438480 void * module_ptr = nullptr ;
@@ -442,7 +484,7 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
442484 void *** init_args[2 ] = {&module_ptr_ptr, &kernel_ptr_ptr};
443485 reinterpret_cast <MosaicInitFunc*>(*init)(init_args);
444486 return CompiledKernel (std::move (*maybe_engine), kernel_ptr,
445- reinterpret_cast <MosaicHostFunc*>(*main ));
487+ reinterpret_cast <MosaicHostFunc*>(*host ));
446488}
447489
448490// Each compiled kernel has a unique init func, and each kernel is used from
0 commit comments