@@ -376,6 +376,16 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
376376 return wrap (res);
377377}
378378
379+ #include " llvm/IRReader/IRReader.h"
380+ extern " C" MlirModule ConvertLLVMStrToMLIR (const char * lmod, MlirContext cctx) {
381+ LLVMContext Context;
382+ SMDiagnostic Err;
383+ auto llvmModule = llvm::parseIR (llvm::MemoryBufferRef (lmod, " conversion" ), Err, Context);
384+ mlir::MLIRContext &context = *unwrap (cctx);
385+ auto res = mlir::translateLLVMIRToModule (std::move (llvmModule), &context, /* emitExpensiveWarnings*/ false , /* dropDICompositeElements*/ false ).release ();
386+ return wrap (res);
387+ }
388+
379389
380390/* Note that this */
381391extern " C" xla::PjRtLoadedExecutable* ClientCompile (PjRtClient * client, MlirModule cmod) {
@@ -460,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
460470 context.loadDialect <mlir::stablehlo::StablehloDialect>();
461471 context.loadDialect <mlir::chlo::ChloDialect>();
462472}
473+
474+ #include " mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
475+ #include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
476+ #include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
463477extern " C" void InitializeRegistryAndPasses (MlirDialectRegistry creg) {
464478 mlir::DialectRegistry ®istry = *unwrap (creg);
465479
@@ -503,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
503517 mlir::affine::registerAffinePasses ();
504518 mlir::registerReconcileUnrealizedCasts ();
505519
520+ mlir::registerLLVMDialectImport (registry);
521+ mlir::registerNVVMDialectImport (registry);
522+
523+ mlir::LLVM::registerInlinerInterface (registry);
524+
506525/*
507526 registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
508527 LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -530,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
530549 mlir::enzyme::registerEnzymeJaxTransformExtension (registry);
531550}
532551
552+
553+ // / Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
554+ // / suffix in `lastUsedID`.
555+ static mlir::StringAttr renameSymbol (llvm::StringRef oldSymName,
556+ unsigned &lastUsedID,
557+ mlir::ModuleOp source,
558+ mlir::ModuleOp target) {
559+ using namespace llvm ;
560+ using namespace mlir ;
561+ SmallString<64 > newSymName (oldSymName);
562+ newSymName.push_back (' _' );
563+ while (true ) {
564+ auto possible = newSymName + Twine (++lastUsedID);
565+ if (!SymbolTable::lookupSymbolIn (source, possible.str ()) && !SymbolTable::lookupSymbolIn (target, possible.str ())) {
566+ return StringAttr::get (target.getContext (), possible);
567+ }
568+ }
569+ }
570+
571+
572+ // / Checks if a symbol with the same name as `op` already exists in `source`.
573+ // / If so, renames `op` and updates all its references in `target`.
574+ static mlir::LogicalResult
575+ updateSymbolAndAllUses (mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
576+ unsigned &lastUsedID) {
577+ using namespace llvm ;
578+ using namespace mlir ;
579+
580+ auto opName = op.getName ().str ();
581+
582+ if (!SymbolTable::lookupSymbolIn (target, opName)) {
583+ return success ();
584+ }
585+
586+ StringAttr newSymName =
587+ renameSymbol (opName, lastUsedID, source, target);
588+
589+ if (failed (SymbolTable::replaceAllSymbolUses (op, newSymName, source)))
590+ return op.emitError (" unable to update all symbol uses for " )
591+ << opName << " to " << newSymName;
592+
593+ SymbolTable::setSymbolName (op, newSymName);
594+ return success ();
595+ }
596+
597+ extern " C" MlirOperation LinkInModule (MlirModule prevModC, MlirModule newModC, const char * entryfn) {
598+ auto prevMod = cast<ModuleOp>(*unwrap (prevModC));
599+ auto newMod = cast<ModuleOp>(*unwrap (newModC));
600+
601+ Operation* entryFn = nullptr ;
602+
603+ unsigned lastUsedID = 0 ;
604+
605+ for (auto &op : *newMod.getBody ()) {
606+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
607+ if (!symbolOp)
608+ continue ;
609+
610+ StringRef oldSymName = symbolOp.getName ();
611+
612+ if (oldSymName == entryfn) {
613+ entryFn = &op;
614+ }
615+
616+ if (failed (updateSymbolAndAllUses (symbolOp, newMod, prevMod,
617+ lastUsedID))) {
618+ assert (0 && " failed to update all uses" );
619+ }
620+ SymbolTable::setSymbolVisibility (&op, SymbolTable::Visibility::Private);
621+ }
622+ prevMod.getBody ()->getOperations ().splice (prevMod.getBody ()->getOperations ().end (),
623+ newMod.getBody ()->getOperations ());
624+ return wrap (entryFn);
625+ }
626+
533627#pragma region xla::ifrt
534628
535629#pragma region xla::ifrt::Value
0 commit comments