Skip to content

Commit 19ea12b

Browse files
authored
[SYCL] Implement device library linking for runtime compilation (#15810)
Mimics `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target (i.e. no AoT, no third-party GPUs, no native CPU). As with the compilation step, warning/error reporting is still rudimentary and will be improved in a future PR.
---------
Signed-off-by: Julian Oppermann <[email protected]>
1 parent 6456fe8 commit 19ea12b

File tree

5 files changed

+310
-31
lines changed

5 files changed

+310
-31
lines changed

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ add_llvm_library(sycl-jit
1919
BitReader
2020
Core
2121
Support
22+
Option
2223
Analysis
2324
IPO
2425
TransformUtils
2526
Passes
27+
IRReader
2628
Linker
2729
ScalarOpts
2830
InstCombine

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,22 +237,30 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
237237
extern "C" KF_EXPORT_SYMBOL JITResult
238238
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
239239
View<const char *> UserArgs) {
240-
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgs);
240+
auto UserArgListOrErr = parseUserArgs(UserArgs);
241+
if (!UserArgListOrErr) {
242+
return errorToFusionResult(UserArgListOrErr.takeError(),
243+
"Parsing of user arguments failed");
244+
}
245+
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);
246+
247+
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
241248
if (!ModuleOrErr) {
242249
return errorToFusionResult(ModuleOrErr.takeError(),
243250
"Device compilation failed");
244251
}
245-
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
246252

247-
SYCLKernelInfo Kernel;
248-
auto Error = translation::KernelTranslator::translateKernel(
249-
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV);
253+
std::unique_ptr<llvm::LLVMContext> Context;
254+
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
255+
Context.reset(&Module->getContext());
250256

251-
auto *LLVMCtx = &Module->getContext();
252-
Module.reset();
253-
delete LLVMCtx;
257+
if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
258+
return errorToFusionResult(std::move(Error), "Device linking failed");
259+
}
254260

255-
if (Error) {
261+
SYCLKernelInfo Kernel;
262+
if (auto Error = translation::KernelTranslator::translateKernel(
263+
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV)) {
256264
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
257265
}
258266

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 239 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,28 @@
88

99
#include "DeviceCompilation.h"
1010

11+
#include <clang/Basic/DiagnosticDriver.h>
1112
#include <clang/Basic/Version.h>
1213
#include <clang/CodeGen/CodeGenAction.h>
1314
#include <clang/Driver/Compilation.h>
15+
#include <clang/Driver/Options.h>
1416
#include <clang/Frontend/CompilerInstance.h>
17+
#include <clang/Frontend/TextDiagnosticBuffer.h>
1518
#include <clang/Tooling/CompilationDatabase.h>
1619
#include <clang/Tooling/Tooling.h>
1720

21+
#include <llvm/IRReader/IRReader.h>
22+
#include <llvm/Linker/Linker.h>
23+
24+
#include <array>
25+
26+
using namespace clang;
27+
using namespace clang::tooling;
28+
using namespace clang::driver;
29+
using namespace clang::driver::options;
30+
using namespace llvm;
31+
using namespace llvm::opt;
32+
1833
#ifdef _GNU_SOURCE
1934
#include <dlfcn.h>
2035
static char X; // Dummy symbol, used as an anchor for `dlinfo` below.
@@ -96,9 +111,6 @@ static const std::string &getDPCPPRoot() {
96111
}
97112

98113
namespace {
99-
using namespace clang;
100-
using namespace clang::tooling;
101-
using namespace clang::driver;
102114

103115
struct GetLLVMModuleAction : public ToolAction {
104116
// Code adapted from `FrontendActionFactory::runInvocation`.
@@ -143,23 +155,37 @@ struct GetLLVMModuleAction : public ToolAction {
143155

144156
} // anonymous namespace
145157

146-
llvm::Expected<std::unique_ptr<llvm::Module>>
158+
Expected<std::unique_ptr<llvm::Module>>
147159
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
148160
View<InMemoryFile> IncludeFiles,
149-
View<const char *> UserArgs) {
161+
const InputArgList &UserArgList) {
150162
const std::string &DPCPPRoot = getDPCPPRoot();
151163
if (DPCPPRoot == InvalidDPCPPRoot) {
152-
return llvm::createStringError("Could not locate DPCPP root directory");
164+
return createStringError("Could not locate DPCPP root directory");
153165
}
154166

155-
SmallVector<std::string> CommandLine = {"-fsycl-device-only"};
156-
// TODO: Allow instrumentation again when device library linking is
157-
// implemented.
158-
CommandLine.push_back("-fno-sycl-instrument-device-code");
159-
CommandLine.append(UserArgs.begin(), UserArgs.end());
160-
clang::tooling::FixedCompilationDatabase DB{".", CommandLine};
167+
DerivedArgList DAL{UserArgList};
168+
const auto &OptTable = getDriverOptTable();
169+
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only));
170+
DAL.AddJoinedArg(
171+
nullptr, OptTable.getOption(OPT_resource_dir_EQ),
172+
(DPCPPRoot + "/lib/clang/" + Twine(CLANG_VERSION_MAJOR)).str());
173+
for (auto *Arg : UserArgList) {
174+
DAL.append(Arg);
175+
}
176+
// Remove args that will trigger an unused command line argument warning for
177+
// the FrontendAction invocation, but are handled later (e.g. during device
178+
// linking).
179+
DAL.eraseArg(OPT_fsycl_device_lib_EQ);
180+
DAL.eraseArg(OPT_fno_sycl_device_lib_EQ);
181+
182+
SmallVector<std::string> CommandLine;
183+
for (auto *Arg : DAL) {
184+
CommandLine.emplace_back(Arg->getAsString(DAL));
185+
}
161186

162-
clang::tooling::ClangTool Tool{DB, {SourceFile.Path}};
187+
FixedCompilationDatabase DB{".", CommandLine};
188+
ClangTool Tool{DB, {SourceFile.Path}};
163189

164190
// Set up in-memory filesystem.
165191
Tool.mapVirtualFile(SourceFile.Path, SourceFile.Contents);
@@ -170,17 +196,14 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
170196
// Reset argument adjusters to drop the `-fsyntax-only` flag which is added by
171197
// default by this API.
172198
Tool.clearArgumentsAdjusters();
173-
// Then, modify argv[0] and set the resource directory so that the driver
174-
// picks up the correct SYCL environment.
199+
// Then, modify argv[0] so that the driver picks up the correct SYCL
200+
// environment. We've already set the resource directory above.
175201
Tool.appendArgumentsAdjuster(
176202
[&DPCPPRoot](const CommandLineArguments &Args,
177203
StringRef Filename) -> CommandLineArguments {
178204
(void)Filename;
179205
CommandLineArguments NewArgs = Args;
180206
NewArgs[0] = (Twine(DPCPPRoot) + "/bin/clang++").str();
181-
NewArgs.push_back((Twine("-resource-dir=") + DPCPPRoot + "/lib/clang/" +
182-
Twine(CLANG_VERSION_MAJOR))
183-
.str());
184207
return NewArgs;
185208
});
186209

@@ -190,5 +213,202 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
190213
}
191214

192215
// TODO: Capture compiler errors from the ClangTool.
193-
return llvm::createStringError("Unable to obtain LLVM module");
216+
return createStringError("Unable to obtain LLVM module");
217+
}
218+
219+
// This function is a simplified copy of the device library selection process in
220+
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
221+
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
222+
static SmallVector<std::string, 8>
223+
getDeviceLibraries(const ArgList &Args, DiagnosticsEngine &Diags) {
224+
struct DeviceLibOptInfo {
225+
StringRef DeviceLibName;
226+
StringRef DeviceLibOption;
227+
};
228+
229+
// Currently, all SYCL device libraries will be linked by default.
230+
llvm::StringMap<bool> DeviceLibLinkInfo = {
231+
{"libc", true}, {"libm-fp32", true}, {"libm-fp64", true},
232+
{"libimf-fp32", true}, {"libimf-fp64", true}, {"libimf-bf16", true},
233+
{"libm-bfloat16", true}, {"internal", true}};
234+
235+
// If -fno-sycl-device-lib is specified, its values will be used to exclude
236+
// linkage of libraries specified by DeviceLibLinkInfo. Linkage of "internal"
237+
// libraries cannot be affected via -fno-sycl-device-lib.
238+
bool ExcludeDeviceLibs = false;
239+
240+
if (Arg *A = Args.getLastArg(OPT_fsycl_device_lib_EQ,
241+
OPT_fno_sycl_device_lib_EQ)) {
242+
if (A->getValues().size() == 0) {
243+
Diags.Report(diag::warn_drv_empty_joined_argument)
244+
<< A->getAsString(Args);
245+
} else {
246+
if (A->getOption().matches(OPT_fno_sycl_device_lib_EQ)) {
247+
ExcludeDeviceLibs = true;
248+
}
249+
250+
for (StringRef Val : A->getValues()) {
251+
if (Val == "all") {
252+
for (const auto &K : DeviceLibLinkInfo.keys()) {
253+
DeviceLibLinkInfo[K] = (K == "internal") || !ExcludeDeviceLibs;
254+
}
255+
break;
256+
}
257+
auto LinkInfoIter = DeviceLibLinkInfo.find(Val);
258+
if (LinkInfoIter == DeviceLibLinkInfo.end() || Val == "internal") {
259+
Diags.Report(diag::err_drv_unsupported_option_argument)
260+
<< A->getSpelling() << Val;
261+
}
262+
DeviceLibLinkInfo[Val] = !ExcludeDeviceLibs;
263+
}
264+
}
265+
}
266+
267+
using SYCLDeviceLibsList = SmallVector<DeviceLibOptInfo, 5>;
268+
269+
const SYCLDeviceLibsList SYCLDeviceWrapperLibs = {
270+
{"libsycl-crt", "libc"},
271+
{"libsycl-complex", "libm-fp32"},
272+
{"libsycl-complex-fp64", "libm-fp64"},
273+
{"libsycl-cmath", "libm-fp32"},
274+
{"libsycl-cmath-fp64", "libm-fp64"},
275+
{"libsycl-imf", "libimf-fp32"},
276+
{"libsycl-imf-fp64", "libimf-fp64"},
277+
{"libsycl-imf-bf16", "libimf-bf16"}};
278+
// ITT annotation libraries are linked in separately whenever the device
279+
// code instrumentation is enabled.
280+
const SYCLDeviceLibsList SYCLDeviceAnnotationLibs = {
281+
{"libsycl-itt-user-wrappers", "internal"},
282+
{"libsycl-itt-compiler-wrappers", "internal"},
283+
{"libsycl-itt-stubs", "internal"}};
284+
285+
SmallVector<std::string, 8> LibraryList;
286+
StringRef LibSuffix = ".bc";
287+
auto AddLibraries = [&](const SYCLDeviceLibsList &LibsList) {
288+
for (const DeviceLibOptInfo &Lib : LibsList) {
289+
if (!DeviceLibLinkInfo[Lib.DeviceLibOption]) {
290+
continue;
291+
}
292+
SmallString<128> LibName(Lib.DeviceLibName);
293+
llvm::sys::path::replace_extension(LibName, LibSuffix);
294+
LibraryList.push_back(Args.MakeArgString(LibName));
295+
}
296+
};
297+
298+
AddLibraries(SYCLDeviceWrapperLibs);
299+
300+
if (Args.hasFlag(OPT_fsycl_instrument_device_code,
301+
OPT_fno_sycl_instrument_device_code, false)) {
302+
AddLibraries(SYCLDeviceAnnotationLibs);
303+
}
304+
305+
return LibraryList;
306+
}
307+
308+
// Links the SYCL device libraries requested by `UserArgList` into `Module`.
//
// The list of libraries comes from `getDeviceLibraries`. Each one is read as
// an IR file from `<DPCPP root>/lib/` into `Module`'s LLVM context and linked
// with `Linker::LinkOnlyNeeded`.
//
// Returns an error if the DPCPP root is unknown, if the library selection
// emitted an error diagnostic, or if a library cannot be parsed or linked.
// Processing stops at the first failing library.
Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
                                        const InputArgList &UserArgList) {
  const std::string &Root = getDPCPPRoot();
  if (Root == InvalidDPCPPRoot) {
    return createStringError("Could not locate DPCPP root directory");
  }

  // TODO: Seems a bit excessive to set up this machinery for one warning and
  //       one error. Rethink when implementing the build log/error reporting as
  //       mandated by the extension.
  IntrusiveRefCntPtr<DiagnosticIDs> IDs{new DiagnosticIDs};
  IntrusiveRefCntPtr<DiagnosticOptions> Options{new DiagnosticOptions};
  // NOTE(review): the raw pointer is handed to the engine below, which
  // presumably takes ownership of the consumer by default - confirm against
  // the `DiagnosticsEngine` constructor.
  TextDiagnosticBuffer *Buffer = new TextDiagnosticBuffer;
  DiagnosticsEngine Diags(IDs, Options, Buffer);

  const auto Libraries = getDeviceLibraries(UserArgList, Diags);

  // Any buffered error aborts linking; the messages are joined by newlines.
  if (Buffer->err_begin() != Buffer->err_end()) {
    std::string Message;
    raw_string_ostream OS{Message};
    for (auto It = Buffer->err_begin(), End = Buffer->err_end(); It != End;
         ++It) {
      if (It != Buffer->err_begin()) {
        OS << '\n';
      }
      OS << It->second;
    }
    return createStringError("Could not determine list of device libraries: %s",
                             Message.c_str());
  }
  // TODO: Add warnings to build log.

  LLVMContext &Ctx = Module.getContext();
  for (const std::string &Name : Libraries) {
    const std::string Path = Root + "/lib/" + Name;

    SMDiagnostic ParseDiag;
    std::unique_ptr<llvm::Module> LibModule =
        parseIRFile(Path, ParseDiag, Ctx);
    if (!LibModule) {
      // Forward the parser's own rendering of the failure.
      std::string Message;
      raw_string_ostream OS(Message);
      ParseDiag.print(/*ProgName=*/nullptr, OS);
      return createStringError(Message);
    }

    const bool LinkFailed = Linker::linkModules(Module, std::move(LibModule),
                                                Linker::LinkOnlyNeeded);
    if (LinkFailed) {
      // TODO: Obtain detailed error message from the context's diagnostics
      //       handler.
      return createStringError("Unable to link device library: %s",
                               Path.c_str());
    }
  }

  return Error::success();
}
358+
359+
// Parses the user-supplied options with the Clang driver's option table.
//
// Returns the parsed argument list, or an error if
//  - an option is missing its argument, or
//  - device AddressSanitizer is requested. That is the case if the last
//    `-fsanitize=`/`-fno-sanitize=` option is `-fsanitize=address` (exactly
//    one value), or - only when no such option is present at all - if
//    `-fsanitize=address` is passed via `-Xsycl-target-frontend` or appears
//    in a value of `-Xarch_device`.
Expected<InputArgList>
jit_compiler::parseUserArgs(View<const char *> UserArgs) {
  unsigned MissingArgIndex = 0;
  unsigned MissingArgCount = 0;
  auto UserArgsRef = UserArgs.to<ArrayRef>();
  auto AL = getDriverOptTable().ParseArgs(UserArgsRef, MissingArgIndex,
                                          MissingArgCount);
  if (MissingArgCount) {
    // `MissingArgIndex` is unsigned, hence `%u`.
    return createStringError(
        "User option '%s' at index %u is missing an argument",
        UserArgsRef[MissingArgIndex], MissingArgIndex);
  }

  // Check for unsupported options.
  // TODO: There are probably more, e.g. requesting non-SPIR-V targets.
  {
    // -fsanitize=address
    bool IsDeviceAsanEnabled = false;
    if (Arg *A = AL.getLastArg(OPT_fsanitize_EQ, OPT_fno_sanitize_EQ)) {
      if (A->getOption().matches(OPT_fsanitize_EQ) &&
          A->getValues().size() == 1) {
        IsDeviceAsanEnabled = StringRef(A->getValue()) == "address";
      }
    } else {
      // True if any value of option `Opt` is exactly `-fsanitize=address`.
      auto PassesAsanFlag = [&AL](OptSpecifier Opt) {
        const auto Values = AL.getAllArgValues(Opt);
        return std::count(Values.begin(), Values.end(),
                          "-fsanitize=address") > 0;
      };

      // User can pass -fsanitize=address to device compiler via
      // -Xsycl-target-frontend.
      IsDeviceAsanEnabled = PassesAsanFlag(OPT_Xsycl_frontend) ||
                            PassesAsanFlag(OPT_Xsycl_frontend_EQ);

      // User can also enable asan for SYCL device via -Xarch_device option.
      if (!IsDeviceAsanEnabled) {
        const auto DeviceArchVals = AL.getAllArgValues(OPT_Xarch_device);
        for (const auto &DArchVal : DeviceArchVals) {
          if (DArchVal.find("-fsanitize=address") != std::string::npos) {
            IsDeviceAsanEnabled = true;
            break;
          }
        }
      }
    }

    if (IsDeviceAsanEnabled) {
      return createStringError(
          "Device ASAN is not supported for runtime compilation");
    }
  }

  return Expected<InputArgList>{std::move(AL)};
}

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "View.h"
1414

1515
#include <llvm/IR/Module.h>
16+
#include <llvm/Option/ArgList.h>
1617
#include <llvm/Support/Error.h>
1718

1819
#include <memory>
@@ -21,7 +22,13 @@ namespace jit_compiler {
2122

2223
llvm::Expected<std::unique_ptr<llvm::Module>>
2324
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
24-
View<const char *> UserArgs);
25+
const llvm::opt::InputArgList &UserArgList);
26+
27+
llvm::Error linkDeviceLibraries(llvm::Module &Module,
28+
const llvm::opt::InputArgList &UserArgList);
29+
30+
llvm::Expected<llvm::opt::InputArgList>
31+
parseUserArgs(View<const char *> UserArgs);
2532

2633
} // namespace jit_compiler
2734

0 commit comments

Comments
 (0)