@@ -451,13 +451,44 @@ const char *const Runtimes = R"(
451451)" ;
452452
453453llvm::Expected<std::unique_ptr<Interpreter>>
454- Interpreter::create (std::unique_ptr<CompilerInstance> CI) {
454+ Interpreter::create (std::unique_ptr<CompilerInstance> CI,
455+ std::unique_ptr<CompilerInstance> DeviceCI) {
455456 llvm::Error Err = llvm::Error::success ();
456457 auto Interp =
457458 std::unique_ptr<Interpreter>(new Interpreter (std::move (CI), Err));
458459 if (Err)
459460 return std::move (Err);
460461
462+ CompilerInstance &HostCI = *(Interp->getCompilerInstance ());
463+
464+ if (DeviceCI) {
465+ Interp->DeviceAct = std::make_unique<IncrementalAction>(
466+ *DeviceCI, *Interp->TSCtx ->getContext (), Err, *Interp);
467+
468+ if (Err)
469+ return std::move (Err);
470+
471+ DeviceCI->ExecuteAction (*Interp->DeviceAct );
472+
473+ // avoid writing fat binary to disk using an in-memory virtual file system
474+ llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> IMVFS =
475+ std::make_unique<llvm::vfs::InMemoryFileSystem>();
476+ llvm::IntrusiveRefCntPtr<llvm::vfs::OverlayFileSystem> OverlayVFS =
477+ std::make_unique<llvm::vfs::OverlayFileSystem>(
478+ llvm::vfs::getRealFileSystem ());
479+ OverlayVFS->pushOverlay (IMVFS);
480+ HostCI.createFileManager (OverlayVFS);
481+
482+ auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
483+ std::move (DeviceCI), HostCI, IMVFS, Err,
484+ Interp->PTUs );
485+
486+ if (Err)
487+ return std::move (Err);
488+
489+ Interp->DeviceParser = std::move (DeviceParser);
490+ }
491+
461492 // Add runtime code and set a marker to hide it from user code. Undo will not
462493 // go through that.
463494 auto PTU = Interp->Parse (Runtimes);
@@ -472,29 +503,7 @@ Interpreter::create(std::unique_ptr<CompilerInstance> CI) {
472503llvm::Expected<std::unique_ptr<Interpreter>>
473504Interpreter::createWithCUDA (std::unique_ptr<CompilerInstance> CI,
474505 std::unique_ptr<CompilerInstance> DCI) {
475- // avoid writing fat binary to disk using an in-memory virtual file system
476- llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> IMVFS =
477- std::make_unique<llvm::vfs::InMemoryFileSystem>();
478- llvm::IntrusiveRefCntPtr<llvm::vfs::OverlayFileSystem> OverlayVFS =
479- std::make_unique<llvm::vfs::OverlayFileSystem>(
480- llvm::vfs::getRealFileSystem ());
481- OverlayVFS->pushOverlay (IMVFS);
482- CI->createFileManager (OverlayVFS);
483-
484- auto Interp = Interpreter::create (std::move (CI));
485- if (auto E = Interp.takeError ())
486- return std::move (E);
487-
488- llvm::Error Err = llvm::Error::success ();
489- auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
490- std::move (DCI), *(*Interp)->getCompilerInstance (), IMVFS, Err,
491- (*Interp)->PTUs );
492- if (Err)
493- return std::move (Err);
494-
495- (*Interp)->DeviceParser = std::move (DeviceParser);
496-
497- return Interp;
506+ return Interpreter::create (std::move (CI), std::move (DCI));
498507}
499508
500509const CompilerInstance *Interpreter::getCompilerInstance () const {
@@ -532,15 +541,16 @@ size_t Interpreter::getEffectivePTUSize() const {
532541
533542PartialTranslationUnit &
534543Interpreter::RegisterPTU (TranslationUnitDecl *TU,
535- std::unique_ptr<llvm::Module> M /* ={}*/ ) {
544+ std::unique_ptr<llvm::Module> M /* ={}*/ ,
545+ IncrementalAction *Action) {
536546 PTUs.emplace_back (PartialTranslationUnit ());
537547 PartialTranslationUnit &LastPTU = PTUs.back ();
538548 LastPTU.TUPart = TU;
539549
540550 if (!M)
541- M = GenModule ();
551+ M = GenModule (Action );
542552
543- assert ((!getCodeGen () || M) && " Must have a llvm::Module at this point" );
553+ assert ((!getCodeGen (Action ) || M) && " Must have a llvm::Module at this point" );
544554
545555 LastPTU.TheModule = std::move (M);
546556 LLVM_DEBUG (llvm::dbgs () << " compile-ptu " << PTUs.size () - 1
@@ -558,8 +568,21 @@ Interpreter::Parse(llvm::StringRef Code) {
558568 // included in the host compilation
559569 if (DeviceParser) {
560570 llvm::Expected<TranslationUnitDecl *> DeviceTU = DeviceParser->Parse (Code);
561- if (auto E = DeviceTU.takeError ())
571+ if (auto E = DeviceTU.takeError ()) {
562572 return std::move (E);
573+ }
574+
575+ auto *CudaParser = llvm::cast<IncrementalCUDADeviceParser>(DeviceParser.get ());
576+
577+ PartialTranslationUnit &DevicePTU = RegisterPTU (*DeviceTU, nullptr , DeviceAct.get ());
578+
579+ llvm::Expected<llvm::StringRef> PTX = CudaParser->GeneratePTX ();
580+ if (!PTX)
581+ return PTX.takeError ();
582+
583+ llvm::Error Err = CudaParser->GenerateFatbinary ();
584+ if (Err)
585+ return std::move (Err);
563586 }
564587
565588 // Tell the interpreter sliently ignore unused expressions since value
@@ -736,9 +759,9 @@ llvm::Error Interpreter::LoadDynamicLibrary(const char *name) {
736759 return llvm::Error::success ();
737760}
738761
739- std::unique_ptr<llvm::Module> Interpreter::GenModule () {
762+ std::unique_ptr<llvm::Module> Interpreter::GenModule (IncrementalAction *Action ) {
740763 static unsigned ID = 0 ;
741- if (CodeGenerator *CG = getCodeGen ()) {
764+ if (CodeGenerator *CG = getCodeGen (Action )) {
742765 // Clang's CodeGen is designed to work with a single llvm::Module. In many
743766 // cases for convenience various CodeGen parts have a reference to the
744767 // llvm::Module (TheModule or Module) which does not change when a new
@@ -760,8 +783,10 @@ std::unique_ptr<llvm::Module> Interpreter::GenModule() {
760783 return nullptr ;
761784}
762785
763- CodeGenerator *Interpreter::getCodeGen () const {
764- FrontendAction *WrappedAct = Act->getWrapped ();
786+ CodeGenerator *Interpreter::getCodeGen (IncrementalAction *Action) const {
787+ if (!Action)
788+ Action = Act.get ();
789+ FrontendAction *WrappedAct = Action->getWrapped ();
765790 if (!WrappedAct->hasIRSupport ())
766791 return nullptr ;
767792 return static_cast <CodeGenAction *>(WrappedAct)->getCodeGenerator ();
0 commit comments