@@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
470470 context.loadDialect <mlir::stablehlo::StablehloDialect>();
471471 context.loadDialect <mlir::chlo::ChloDialect>();
472472}
473+
474+ #include " mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
475+ #include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
476+ #include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
473477extern " C" void InitializeRegistryAndPasses (MlirDialectRegistry creg) {
474478 mlir::DialectRegistry ®istry = *unwrap (creg);
475479
@@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
513517 mlir::affine::registerAffinePasses ();
514518 mlir::registerReconcileUnrealizedCasts ();
515519
520+ mlir::registerLLVMDialectImport (registry);
521+ mlir::registerNVVMDialectImport (registry);
522+
523+ mlir::LLVM::registerInlinerInterface (registry);
524+
516525/*
517526 registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
518527 LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
540549 mlir::enzyme::registerEnzymeJaxTransformExtension (registry);
541550}
542551
552+
553+ // / Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
554+ // / suffix in `lastUsedID`.
555+ static mlir::StringAttr renameSymbol (llvm::StringRef oldSymName,
556+ unsigned &lastUsedID,
557+ mlir::ModuleOp source,
558+ mlir::ModuleOp target) {
559+ using namespace llvm ;
560+ using namespace mlir ;
561+ SmallString<64 > newSymName (oldSymName);
562+ newSymName.push_back (' _' );
563+ while (true ) {
564+ auto possible = newSymName + Twine (++lastUsedID);
565+ if (!SymbolTable::lookupSymbolIn (source, possible.str ()) && !SymbolTable::lookupSymbolIn (target, possible.str ())) {
566+ return StringAttr::get (target.getContext (), possible);
567+ }
568+ }
569+ }
570+
571+
572+ // / Checks if a symbol with the same name as `op` already exists in `source`.
573+ // / If so, renames `op` and updates all its references in `target`.
574+ static mlir::LogicalResult
575+ updateSymbolAndAllUses (mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
576+ unsigned &lastUsedID) {
577+ using namespace llvm ;
578+ using namespace mlir ;
579+
580+ auto opName = op.getName ().str ();
581+
582+ if (!SymbolTable::lookupSymbolIn (target, opName)) {
583+ return success ();
584+ }
585+
586+ StringAttr newSymName =
587+ renameSymbol (opName, lastUsedID, source, target);
588+
589+ if (failed (SymbolTable::replaceAllSymbolUses (op, newSymName, source)))
590+ return op.emitError (" unable to update all symbol uses for " )
591+ << opName << " to " << newSymName;
592+
593+ SymbolTable::setSymbolName (op, newSymName);
594+ return success ();
595+ }
596+
597+ extern " C" MlirOperation LinkInModule (MlirModule prevModC, MlirModule newModC, const char * entryfn) {
598+ auto prevMod = cast<ModuleOp>(*unwrap (prevModC));
599+ auto newMod = cast<ModuleOp>(*unwrap (newModC));
600+
601+ Operation* entryFn = nullptr ;
602+
603+ unsigned lastUsedID = 0 ;
604+
605+ for (auto &op : *newMod.getBody ()) {
606+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
607+ if (!symbolOp)
608+ continue ;
609+
610+ StringRef oldSymName = symbolOp.getName ();
611+
612+ if (oldSymName == entryfn) {
613+ entryFn = &op;
614+ }
615+
616+ if (failed (updateSymbolAndAllUses (symbolOp, newMod, prevMod,
617+ lastUsedID))) {
618+ assert (0 && " failed to update all uses" );
619+ }
620+ SymbolTable::setSymbolVisibility (&op, SymbolTable::Visibility::Private);
621+ }
622+ prevMod.getBody ()->getOperations ().splice (prevMod.getBody ()->getOperations ().end (),
623+ newMod.getBody ()->getOperations ());
624+ return wrap (entryFn);
625+ }
626+
543627#pragma region xla::ifrt
544628
545629#pragma region xla::ifrt::Value
0 commit comments