diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h index f1353777f6ce9..03165895b85d0 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h @@ -153,6 +153,12 @@ class ThreadSafeModule { using GVPredicate = std::function; using GVModifier = std::function; +/// Clones teh given module onto the given context. +LLVM_ABI ThreadSafeModule +cloneToContext(const ThreadSafeModule &TSMW, ThreadSafeContext TSCtx, + GVPredicate ShouldCloneDef = GVPredicate(), + GVModifier UpdateClonedDefSource = GVModifier()); + /// Clones the given module on to a new context. LLVM_ABI ThreadSafeModule cloneToNewContext( const ThreadSafeModule &TSMW, GVPredicate ShouldCloneDef = GVPredicate(), diff --git a/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp b/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp index fadd53eee21b9..19c000e2472a8 100644 --- a/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp +++ b/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp @@ -14,51 +14,63 @@ namespace llvm { namespace orc { -ThreadSafeModule cloneToNewContext(const ThreadSafeModule &TSM, - GVPredicate ShouldCloneDef, - GVModifier UpdateClonedDefSource) { +ThreadSafeModule cloneToContext(const ThreadSafeModule &TSM, + ThreadSafeContext TSCtx, + GVPredicate ShouldCloneDef, + GVModifier UpdateClonedDefSource) { assert(TSM && "Can not clone null module"); if (!ShouldCloneDef) ShouldCloneDef = [](const GlobalValue &) { return true; }; - return TSM.withModuleDo([&](Module &M) { - SmallVector ClonedModuleBuffer; + // First copy the source module into a buffer. + std::string ModuleName; + SmallVector ClonedModuleBuffer; + TSM.withModuleDo([&](Module &M) { + ModuleName = M.getModuleIdentifier(); + std::set ClonedDefsInSrc; + ValueToValueMapTy VMap; + auto Tmp = CloneModule(M, VMap, [&](const GlobalValue *GV) { + if (ShouldCloneDef(*GV)) { + ClonedDefsInSrc.insert(const_cast(GV)); + return true; + } + return false; + }); - { - std::set ClonedDefsInSrc; - ValueToValueMapTy VMap; - auto Tmp = CloneModule(M, VMap, [&](const GlobalValue *GV) { - if (ShouldCloneDef(*GV)) { - ClonedDefsInSrc.insert(const_cast(GV)); - return true; - } - return false; - }); + if (UpdateClonedDefSource) + for (auto *GV : ClonedDefsInSrc) + UpdateClonedDefSource(*GV); - if (UpdateClonedDefSource) - for (auto *GV : ClonedDefsInSrc) - UpdateClonedDefSource(*GV); + BitcodeWriter BCWriter(ClonedModuleBuffer); + BCWriter.writeModule(*Tmp); + BCWriter.writeSymtab(); + BCWriter.writeStrtab(); + }); + + MemoryBufferRef ClonedModuleBufferRef( + StringRef(ClonedModuleBuffer.data(), ClonedModuleBuffer.size()), + "cloned module buffer"); - BitcodeWriter BCWriter(ClonedModuleBuffer); + // Then parse the buffer into the new Module. + auto M = TSCtx.withContextDo([&](LLVMContext *Ctx) { + assert(Ctx && "No LLVMContext provided"); + auto TmpM = cantFail(parseBitcodeFile(ClonedModuleBufferRef, *Ctx)); + TmpM->setModuleIdentifier(ModuleName); + return TmpM; + }); - BCWriter.writeModule(*Tmp); - BCWriter.writeSymtab(); - BCWriter.writeStrtab(); - } + return ThreadSafeModule(std::move(M), std::move(TSCtx)); +} - MemoryBufferRef ClonedModuleBufferRef( - StringRef(ClonedModuleBuffer.data(), ClonedModuleBuffer.size()), - "cloned module buffer"); - ThreadSafeContext NewTSCtx(std::make_unique()); +ThreadSafeModule cloneToNewContext(const ThreadSafeModule &TSM, + GVPredicate ShouldCloneDef, + GVModifier UpdateClonedDefSource) { + assert(TSM && "Can not clone null module"); - auto ClonedModule = NewTSCtx.withContextDo([&](LLVMContext *Ctx) { - auto TmpM = cantFail(parseBitcodeFile(ClonedModuleBufferRef, *Ctx)); - TmpM->setModuleIdentifier(M.getName()); - return TmpM; - }); - return ThreadSafeModule(std::move(ClonedModule), std::move(NewTSCtx)); - }); + ThreadSafeContext TSCtx(std::make_unique()); + return cloneToContext(TSM, std::move(TSCtx), std::move(ShouldCloneDef), + std::move(UpdateClonedDefSource)); } } // end namespace orc diff --git a/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp b/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp index adaa4d97ca5f4..bbb9e8d3d6a75 100644 --- a/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp @@ -7,6 +7,13 @@ //===----------------------------------------------------------------------===// #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + #include "gtest/gtest.h" #include @@ -18,6 +25,24 @@ using namespace llvm::orc; namespace { +const llvm::StringRef FooSrc = R"( + define void @foo() { + ret void + } +)"; + +static ThreadSafeModule parseModule(llvm::StringRef Source, + llvm::StringRef Name) { + auto Ctx = std::make_unique(); + SMDiagnostic Err; + auto M = parseIR(MemoryBufferRef(Source, Name), Err, *Ctx); + if (!M) { + Err.print("Testcase source failed to parse: ", errs()); + exit(1); + } + return ThreadSafeModule(std::move(M), std::move(Ctx)); +} + TEST(ThreadSafeModuleTest, ContextWhollyOwnedByOneModule) { // Test that ownership of a context can be transferred to a single // ThreadSafeModule. @@ -103,4 +128,28 @@ TEST(ThreadSafeModuleTest, ConsumingModuleDo) { TSM.consumingModuleDo([](std::unique_ptr M) {}); } +TEST(ThreadSafeModuleTest, CloneToNewContext) { + auto TSM1 = parseModule(FooSrc, "foo.ll"); + auto TSM2 = cloneToNewContext(TSM1); + TSM2.withModuleDo([&](Module &NewM) { + EXPECT_FALSE(verifyModule(NewM, &errs())); + TSM1.withModuleDo([&](Module &OrigM) { + EXPECT_NE(&NewM.getContext(), &OrigM.getContext()); + }); + }); +} + +TEST(ObjectFormatsTest, CloneToContext) { + auto TSM1 = parseModule(FooSrc, "foo.ll"); + + auto TSCtx = ThreadSafeContext(std::make_unique()); + auto TSM2 = cloneToContext(TSM1, TSCtx); + + TSM2.withModuleDo([&](Module &M) { + EXPECT_FALSE(verifyModule(M, &errs())); + TSCtx.withContextDo( + [&](LLVMContext *Ctx) { EXPECT_EQ(&M.getContext(), Ctx); }); + }); +} + } // end anonymous namespace