88
99#include " DeviceCompilation.h"
1010#include " ESIMD.h"
11+ #include " JITBinaryInfo.h"
12+ #include " translation/Translation.h"
1113
1214#include < clang/Basic/DiagnosticDriver.h>
1315#include < clang/Basic/Version.h>
1416#include < clang/CodeGen/CodeGenAction.h>
1517#include < clang/Driver/Compilation.h>
18+ #include < clang/Driver/Driver.h>
1619#include < clang/Driver/Options.h>
20+ #include < clang/Driver/ToolChain.h>
1721#include < clang/Frontend/ChainedDiagnosticConsumer.h>
1822#include < clang/Frontend/CompilerInstance.h>
1923#include < clang/Frontend/FrontendActions.h>
@@ -52,6 +56,7 @@ using namespace llvm::opt;
5256using namespace llvm ::sycl;
5357using namespace llvm ::module_split;
5458using namespace llvm ::util;
59+ using namespace llvm ::vfs;
5560using namespace jit_compiler ;
5661
5762#ifdef _GNU_SOURCE
@@ -313,7 +318,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
313318} // anonymous namespace
314319
315320static void adjustArgs (const InputArgList &UserArgList,
316- const std::string &DPCPPRoot,
321+ const std::string &DPCPPRoot, BinaryFormat Format,
317322 SmallVectorImpl<std::string> &CommandLine) {
318323 DerivedArgList DAL{UserArgList};
319324 const auto &OptTable = getDriverOptTable ();
@@ -326,6 +331,17 @@ static void adjustArgs(const InputArgList &UserArgList,
326331 // unused argument warning.
327332 DAL.AddFlagArg (nullptr , OptTable.getOption (OPT_Qunused_arguments));
328333
334+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
335+ auto [CPU, Features] =
336+ Translator::getTargetCPUAndFeatureAttrs (nullptr , " " , Format);
337+ (void )Features;
338+ StringRef OT = Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda"
339+ : " amdgcn-amd-amdhsa" ;
340+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_fsycl_targets_EQ), OT);
341+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_Xsycl_backend_EQ), OT);
342+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_offload_arch_EQ), CPU);
343+ }
344+
329345 ArgStringList ASL;
330346 for_each (DAL, [&DAL, &ASL](Arg *A) { A->render (DAL, ASL); });
331347 for_each (UserArgList,
@@ -362,10 +378,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
362378 });
363379}
364380
365- Expected<std::string>
366- jit_compiler::calculateHash (InMemoryFile SourceFile,
367- View<InMemoryFile> IncludeFiles,
368- const InputArgList &UserArgList) {
381+ Expected<std::string> jit_compiler::calculateHash (
382+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
383+ const InputArgList &UserArgList, BinaryFormat Format) {
369384 TimeTraceScope TTS{" calculateHash" };
370385
371386 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -374,7 +389,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
374389 }
375390
376391 SmallVector<std::string> CommandLine;
377- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
392+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
378393
379394 FixedCompilationDatabase DB{" ." , CommandLine};
380395 ClangTool Tool{DB, {SourceFile.Path }};
@@ -400,11 +415,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
400415 return createStringError (" Calculating source hash failed" );
401416}
402417
403- Expected<ModuleUPtr>
404- jit_compiler::compileDeviceCode (InMemoryFile SourceFile,
405- View<InMemoryFile> IncludeFiles,
406- const InputArgList &UserArgList,
407- std::string &BuildLog, LLVMContext &Context) {
418+ Expected<ModuleUPtr> jit_compiler::compileDeviceCode (
419+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
420+ const InputArgList &UserArgList, std::string &BuildLog,
421+ LLVMContext &Context, BinaryFormat Format) {
408422 TimeTraceScope TTS{" compileDeviceCode" };
409423
410424 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -413,7 +427,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
413427 }
414428
415429 SmallVector<std::string> CommandLine;
416- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
430+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
417431
418432 FixedCompilationDatabase DB{" ." , CommandLine};
419433 ClangTool Tool{DB, {SourceFile.Path }};
@@ -431,12 +445,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
431445 return createStringError (BuildLog);
432446}
433447
434- // This function is a simplified copy of the device library selection process in
435- // `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
436- // (no AoT, no third-party GPUs , no native CPU). Keep in sync!
448+ // This function is a simplified copy of the device library selection process
449+ // in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V, or
450+ // GPU targets ( no AoT , no native CPU). Keep in sync!
437451static bool getDeviceLibraries (const ArgList &Args,
438452 SmallVectorImpl<std::string> &LibraryList,
439- DiagnosticsEngine &Diags) {
453+ DiagnosticsEngine &Diags, BinaryFormat Format) {
454+ // For CUDA/HIP we only need devicelib, early exit here.
455+ if (Format == BinaryFormat::PTX) {
456+ LibraryList.push_back (
457+ Args.MakeArgString (" devicelib-nvptx64-nvidia-cuda.bc" ));
458+ return false ;
459+ } else if (Format == BinaryFormat::AMDGCN) {
460+ LibraryList.push_back (Args.MakeArgString (" devicelib-amdgcn-amd-amdhsa.bc" ));
461+ return false ;
462+ }
463+
440464 struct DeviceLibOptInfo {
441465 StringRef DeviceLibName;
442466 StringRef DeviceLibOption;
@@ -541,7 +565,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
541565
542566Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
543567 const InputArgList &UserArgList,
544- std::string &BuildLog) {
568+ std::string &BuildLog,
569+ BinaryFormat Format) {
545570 TimeTraceScope TTS{" linkDeviceLibraries" };
546571
547572 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -556,11 +581,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
556581 /* ShouldOwnClient=*/ false );
557582
558583 SmallVector<std::string> LibNames;
559- bool FoundUnknownLib = getDeviceLibraries (UserArgList, LibNames, Diags);
584+ const bool FoundUnknownLib =
585+ getDeviceLibraries (UserArgList, LibNames, Diags, Format);
560586 if (FoundUnknownLib) {
561587 return createStringError (" Could not determine list of device libraries: %s" ,
562588 BuildLog.c_str ());
563589 }
590+ const bool IsCudaHIP =
591+ Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
592+ if (IsCudaHIP) {
593+ // Based on the OS and the format decide on the version of libspirv.
594+ // NOTE: this will be problematic if cross-compiling between OSes.
595+ std::string Libclc{" clc/" };
596+ Libclc.append (
597+ #ifdef _WIN32
598+ " remangled-l32-signed_char.libspirv-"
599+ #else
600+ " remangled-l64-signed_char.libspirv-"
601+ #endif
602+ );
603+ Libclc.append (Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda.bc"
604+ : " amdgcn-amd-amdhsa.bc" );
605+ LibNames.push_back (Libclc);
606+ }
564607
565608 LLVMContext &Context = Module.getContext ();
566609 for (const std::string &LibName : LibNames) {
@@ -578,6 +621,72 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
578621 }
579622 }
580623
624+ // For GPU targets we need to link against vendor provided libdevice.
625+ if (IsCudaHIP) {
626+ std::string Argv0 = DPCPPRoot + " /bin/clang++" ;
627+ Triple T{Module.getTargetTriple ()};
628+ IntrusiveRefCntPtr<OverlayFileSystem> OFS{
629+ new OverlayFileSystem{getRealFileSystem ()}};
630+ IntrusiveRefCntPtr<InMemoryFileSystem> VFS{new InMemoryFileSystem};
631+ std::string CppFileName{" a.cpp" };
632+ VFS->addFile (CppFileName, /* ModificationTime=*/ 0 ,
633+ MemoryBuffer::getMemBuffer (" " , " " ));
634+ OFS->pushOverlay (VFS);
635+ Driver D{Argv0, T.getTriple (), Diags, " dpcpp compiler driver" , OFS};
636+
637+ SmallVector<std::string> CommandLine;
638+ CommandLine.push_back (Argv0);
639+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
640+ CommandLine.push_back (CppFileName);
641+ SmallVector<const char *> CommandLineCStr (CommandLine.size ());
642+ llvm::transform (CommandLine, CommandLineCStr.begin (),
643+ [](const auto &S) { return S.c_str (); });
644+
645+ Compilation *C = D.BuildCompilation (CommandLineCStr);
646+ if (!C) {
647+ return createStringError (" Unable to construct driver for CUDA/HIP" );
648+ }
649+
650+ const ToolChain *OffloadTC =
651+ C->getSingleOffloadToolChain <Action::OFK_SYCL>();
652+ InputArgList EmptyArgList;
653+ auto Archs =
654+ D.getOffloadArchs (*C, EmptyArgList, Action::OFK_SYCL, OffloadTC);
655+ assert (Archs.size () == 1 &&
656+ " Offload toolchain should be configured to single architecture" );
657+ StringRef CPU = *Archs.begin ();
658+
659+ // Pass only `-march=` or `-mcpu=` with the GPU arch determined by the
660+ // driver to `getDeviceLibs`.
661+ DerivedArgList CPUArgList{EmptyArgList};
662+ if (Format == BinaryFormat::PTX) {
663+ CPUArgList.AddJoinedArg (nullptr , D.getOpts ().getOption (OPT_march_EQ),
664+ CPU);
665+ } else {
666+ CPUArgList.AddJoinedArg (nullptr , D.getOpts ().getOption (OPT_mcpu_EQ), CPU);
667+ }
668+
669+ SmallVector<ToolChain::BitCodeLibraryInfo, 12 > CommonDeviceLibs =
670+ OffloadTC->getDeviceLibs (CPUArgList, Action::OffloadKind::OFK_SYCL);
671+ if (CommonDeviceLibs.empty ()) {
672+ return createStringError (" Unable to find common device libraries" );
673+ }
674+
675+ for (auto &Lib : CommonDeviceLibs) {
676+ ModuleUPtr LibModule;
677+ if (auto Error =
678+ loadBitcodeLibrary (Lib.Path , Context).moveInto (LibModule)) {
679+ return Error;
680+ }
681+
682+ if (Linker::linkModules (Module, std::move (LibModule),
683+ Linker::LinkOnlyNeeded)) {
684+ return createStringError (" Unable to link device library %s: %s" ,
685+ Lib.Path .c_str (), BuildLog.c_str ());
686+ }
687+ }
688+ }
689+
581690 return Error::success ();
582691}
583692
0 commit comments