@@ -64,86 +64,6 @@ using namespace llvm::util;
6464using namespace llvm ::vfs;
6565using namespace jit_compiler ;
6666
67- #ifdef _GNU_SOURCE
68- #include < dlfcn.h>
69- static char X; // Dummy symbol, used as an anchor for `dlinfo` below.
70- #endif
71-
72- #ifdef _WIN32
73- #include < filesystem> // For std::filesystem::path ( C++17 only )
74- #include < shlwapi.h> // For PathRemoveFileSpec
75- #include < windows.h> // For GetModuleFileName, HMODULE, DWORD, MAX_PATH
76-
77- // cribbed from sycl/source/detail/os_util.cpp
78- using OSModuleHandle = intptr_t ;
79- static constexpr OSModuleHandle ExeModuleHandle = -1 ;
80- static OSModuleHandle getOSModuleHandle (const void *VirtAddr) {
81- HMODULE PhModule;
82- DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
83- GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
84- auto LpModuleAddr = reinterpret_cast <LPCSTR>(VirtAddr);
85- if (!GetModuleHandleExA (Flag, LpModuleAddr, &PhModule)) {
86- // Expect the caller to check for zero and take
87- // necessary action
88- return 0 ;
89- }
90- if (PhModule == GetModuleHandleA (nullptr ))
91- return ExeModuleHandle;
92- return reinterpret_cast <OSModuleHandle>(PhModule);
93- }
94-
95- // cribbed from sycl/source/detail/os_util.cpp
96- // / Returns an absolute path where the object was found.
97- std::wstring getCurrentDSODir () {
98- wchar_t Path[MAX_PATH];
99- auto Handle = getOSModuleHandle (reinterpret_cast <void *>(&getCurrentDSODir));
100- DWORD Ret = GetModuleFileName (
101- reinterpret_cast <HMODULE>(ExeModuleHandle == Handle ? 0 : Handle), Path,
102- MAX_PATH);
103- assert (Ret < MAX_PATH && " Path is longer than MAX_PATH?" );
104- assert (Ret > 0 && " GetModuleFileName failed" );
105- (void )Ret;
106-
107- BOOL RetCode = PathRemoveFileSpec (Path);
108- assert (RetCode && " PathRemoveFileSpec failed" );
109- (void )RetCode;
110-
111- return Path;
112- }
113- #endif // _WIN32
114-
115- static constexpr auto InvalidDPCPPRoot = " <invalid>" ;
116-
117- static const std::string &getDPCPPRoot () {
118- thread_local std::string DPCPPRoot;
119-
120- if (!DPCPPRoot.empty ()) {
121- return DPCPPRoot;
122- }
123- DPCPPRoot = InvalidDPCPPRoot;
124-
125- #ifdef _GNU_SOURCE
126- static constexpr auto JITLibraryPathSuffix = " /lib/libsycl-jit.so" ;
127- Dl_info Info;
128- if (dladdr (&X, &Info)) {
129- std::string LoadedLibraryPath = Info.dli_fname ;
130- auto Pos = LoadedLibraryPath.rfind (JITLibraryPathSuffix);
131- if (Pos != std::string::npos) {
132- DPCPPRoot = LoadedLibraryPath.substr (0 , Pos);
133- }
134- }
135- #endif // _GNU_SOURCE
136-
137- #ifdef _WIN32
138- DPCPPRoot = std::filesystem::path (getCurrentDSODir ()).parent_path ().string ();
139- #endif // _WIN32
140-
141- // TODO: Implemenent other means of determining the DPCPP root, e.g.
142- // evaluating the `CMPLR_ROOT` env.
143-
144- return DPCPPRoot;
145- }
146-
14767namespace {
14868
14969class HashPreprocessedAction : public PreprocessorFrontendAction {
@@ -252,6 +172,30 @@ class SYCLToolchain {
252172 return TI.run ();
253173 }
254174
175+ Expected<ModuleUPtr> loadBitcodeLibrary (StringRef LibPath,
176+ LLVMContext &Context) {
177+ auto FS = llvm::makeIntrusiveRefCnt<llvm::vfs::OverlayFileSystem>(
178+ llvm::vfs::getRealFileSystem ());
179+ FS->pushOverlay (ToolchainFS);
180+
181+ auto MemBuf = FS->getBufferForFile (LibPath, /* FileSize*/ -1 ,
182+ /* RequiresNullTerminator*/ false );
183+ if (!MemBuf) {
184+ return createStringError (" Error opening file %s: %s" , LibPath.data (),
185+ MemBuf.getError ().message ().c_str ());
186+ }
187+
188+ SMDiagnostic Diag;
189+ ModuleUPtr Lib = parseIR (*MemBuf->get (), Diag, Context);
190+ if (!Lib) {
191+ std::string DiagMsg;
192+ raw_string_ostream SOS (DiagMsg);
193+ Diag.print (/* ProgName=*/ nullptr , SOS);
194+ return createStringError (DiagMsg);
195+ }
196+ return std::move (Lib);
197+ }
198+
255199 std::string_view getClangXXExe () const { return ClangXXExe; }
256200
257201private:
@@ -516,30 +460,12 @@ static bool getDeviceLibraries(const ArgList &Args,
516460 return FoundUnknownLib;
517461}
518462
519- static Expected<ModuleUPtr> loadBitcodeLibrary (StringRef LibPath,
520- LLVMContext &Context) {
521- SMDiagnostic Diag;
522- ModuleUPtr Lib = parseIRFile (LibPath, Diag, Context);
523- if (!Lib) {
524- std::string DiagMsg;
525- raw_string_ostream SOS (DiagMsg);
526- Diag.print (/* ProgName=*/ nullptr , SOS);
527- return createStringError (DiagMsg);
528- }
529- return std::move (Lib);
530- }
531-
532463Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
533464 const InputArgList &UserArgList,
534465 std::string &BuildLog,
535466 BinaryFormat Format) {
536467 TimeTraceScope TTS{" linkDeviceLibraries" };
537468
538- const std::string &DPCPPRoot = getDPCPPRoot ();
539- if (DPCPPRoot == InvalidDPCPPRoot) {
540- return createStringError (" Could not locate DPCPP root directory" );
541- }
542-
543469 IntrusiveRefCntPtr<DiagnosticIDs> DiagID{new DiagnosticIDs};
544470 DiagnosticOptions DiagOpts;
545471 ClangDiagnosticWrapper Wrapper (BuildLog, &DiagOpts);
@@ -573,10 +499,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
573499
574500 LLVMContext &Context = Module.getContext ();
575501 for (const std::string &LibName : LibNames) {
576- std::string LibPath = DPCPPRoot + " /lib/" + LibName;
502+ std::string LibPath =
503+ (jit_compiler::ToolchainPrefix + " /lib/" + LibName).str ();
577504
578505 ModuleUPtr LibModule;
579- if (auto Error = loadBitcodeLibrary (LibPath, Context).moveInto (LibModule)) {
506+ if (auto Error = SYCLToolchain::instance ()
507+ .loadBitcodeLibrary (LibPath, Context)
508+ .moveInto (LibModule)) {
580509 return Error;
581510 }
582511
@@ -590,14 +519,16 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
590519 // For GPU targets we need to link against vendor provided libdevice.
591520 if (IsCudaHIP) {
592521 Triple T{Module.getTargetTriple ()};
593- Driver D{(Twine (DPCPPRoot) + " /bin/clang++" ).str (), T.getTriple (), Diags};
522+ Driver D{(jit_compiler::ToolchainPrefix + " /bin/clang++" ).str (),
523+ T.getTriple (), Diags};
594524 auto [CPU, Features] =
595525 Translator::getTargetCPUAndFeatureAttrs (&Module, " " , Format);
596526 (void )Features;
597527 // Helper lambda to link modules.
598528 auto LinkInLib = [&](const StringRef LibDevice) -> Error {
599529 ModuleUPtr LibDeviceModule;
600- if (auto Error = loadBitcodeLibrary (LibDevice, Context)
530+ if (auto Error = SYCLToolchain::instance ()
531+ .loadBitcodeLibrary (LibDevice, Context)
601532 .moveInto (LibDeviceModule)) {
602533 return Error;
603534 }
@@ -831,16 +762,14 @@ jit_compiler::performPostLink(ModuleUPtr Module,
831762 }
832763
833764 if (IsBF16DeviceLibUsed) {
834- const std::string &DPCPPRoot = getDPCPPRoot ();
835- if (DPCPPRoot == InvalidDPCPPRoot) {
836- return createStringError (" Could not locate DPCPP root directory" );
837- }
838-
839765 auto &Ctx = Modules.front ()->getContext ();
840766 auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error {
841- std::string LibPath = DPCPPRoot + " /lib/" + LibName;
767+ std::string LibPath =
768+ (jit_compiler::ToolchainPrefix + " /lib/" + LibName).str ();
842769 ModuleUPtr LibModule;
843- if (auto Error = loadBitcodeLibrary (LibPath, Ctx).moveInto (LibModule)) {
770+ if (auto Error = SYCLToolchain::instance ()
771+ .loadBitcodeLibrary (LibPath, Ctx)
772+ .moveInto (LibModule)) {
844773 return Error;
845774 }
846775
0 commit comments