Skip to content

Commit 6e892f9

Browse files
committed
[SYCL] Add plumbing to implement kernel compiler extension with libtooling and sycl-jit
Signed-off-by: Julian Oppermann <[email protected]>
1 parent d3c5733 commit 6e892f9

File tree

15 files changed

+424
-0
lines changed

15 files changed

+424
-0
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,11 @@ struct SYCLKernelInfo {
349349
: Name{KernelName}, Args{NumArgs}, Attributes{}, NDR{}, BinaryInfo{} {}
350350
};
351351

352+
struct IncludePair {
353+
const char *Path;
354+
const char *Contents;
355+
};
356+
352357
} // namespace jit_compiler
353358

354359
#endif // SYCL_FUSION_COMMON_KERNEL_H

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_llvm_library(sycl-jit
77
lib/fusion/FusionHelper.cpp
88
lib/fusion/JITContext.cpp
99
lib/fusion/ModuleHelper.cpp
10+
lib/rtc/DeviceCompilation.cpp
1011
lib/helper/ConfigHelper.cpp
1112

1213
SHARED
@@ -29,6 +30,14 @@ add_llvm_library(sycl-jit
2930
TargetParser
3031
MC
3132
${LLVM_TARGETS_TO_BUILD}
33+
34+
LINK_LIBS
35+
clangBasic
36+
clangDriver
37+
clangFrontend
38+
clangCodeGen
39+
clangTooling
40+
clangSerialization
3241
)
3342

3443
target_compile_options(sycl-jit PRIVATE ${SYCL_JIT_WARNING_FLAGS})
@@ -40,6 +49,8 @@ target_include_directories(sycl-jit
4049
SYSTEM PRIVATE
4150
${LLVM_MAIN_INCLUDE_DIR}
4251
${LLVM_SPIRV_INCLUDE_DIRS}
52+
${CMAKE_SOURCE_DIR}/../clang/include
53+
${CMAKE_BINARY_DIR}/tools/clang/include
4354
)
4455
target_include_directories(sycl-jit
4556
PUBLIC

sycl-jit/jit-compiler/include/KernelFusion.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ JITResult materializeSpecConstants(const char *KernelName,
6666
jit_compiler::SYCLKernelBinaryInfo &BinInfo,
6767
View<unsigned char> SpecConstBlob);
6868

69+
JITResult compileSYCL(const char *SYCLSource, View<IncludePair> IncludePairs,
70+
View<const char *> UserArgs, const char *DPCPPRoot);
71+
6972
/// Clear all previously set options.
7073
void resetJITConfiguration();
7174

sycl-jit/jit-compiler/ld-version-script.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/* Export the library entry points */
44
fuseKernels;
55
materializeSpecConstants;
6+
compileSYCL;
67
resetJITConfiguration;
78
addToJITConfiguration;
89

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "fusion/FusionPipeline.h"
1515
#include "helper/ConfigHelper.h"
1616
#include "helper/ErrorHandling.h"
17+
#include "rtc/DeviceCompilation.h"
1718
#include "translation/KernelTranslation.h"
1819
#include "translation/SPIRVLLVMTranslation.h"
1920
#include <llvm/Support/Error.h>
@@ -235,6 +236,31 @@ extern "C" JITResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
235236
return JITResult{FusedKernelInfo};
236237
}
237238

239+
extern "C" JITResult compileSYCL(const char *SYCLSource,
240+
View<IncludePair> IncludePairs,
241+
View<const char *> UserArgs,
242+
const char *DPCPPRoot) {
243+
std::unique_ptr<llvm::Module> Module =
244+
compileDeviceCode(SYCLSource, IncludePairs, UserArgs, DPCPPRoot);
245+
if (!Module) {
246+
return JITResult{"Device code compilation failed"};
247+
}
248+
249+
SYCLKernelInfo Kernel;
250+
auto Error = translation::KernelTranslator::translateKernel(
251+
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV);
252+
253+
auto *LLVMCtx = &Module->getContext();
254+
Module.reset();
255+
delete LLVMCtx;
256+
257+
if (Error) {
258+
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
259+
}
260+
261+
return JITResult{Kernel};
262+
}
263+
238264
extern "C" void resetJITConfiguration() { ConfigHelper::reset(); }
239265

240266
extern "C" void addToJITConfiguration(OptionStorage &&Opt) {
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//==---------------------- DeviceCompilation.cpp ---------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "DeviceCompilation.h"
10+
11+
#include <clang/Basic/Version.h>
12+
#include <clang/CodeGen/CodeGenAction.h>
13+
#include <clang/Driver/Compilation.h>
14+
#include <clang/Frontend/CompilerInstance.h>
15+
#include <clang/Tooling/CompilationDatabase.h>
16+
#include <clang/Tooling/Tooling.h>
17+
18+
#include <llvm/IR/Module.h>
19+
20+
namespace {
21+
using namespace clang;
22+
using namespace clang::tooling;
23+
using namespace clang::driver;
24+
25+
struct GetLLVMModuleAction : public ToolAction {
26+
// Code adapted from `FrontendActionFactory::runInvocation`.
27+
bool runInvocation(std::shared_ptr<CompilerInvocation> Invocation,
28+
FileManager *Files,
29+
std::shared_ptr<PCHContainerOperations> PCHContainerOps,
30+
DiagnosticConsumer *DiagConsumer) override {
31+
assert(!Module && "Action should only be invoked on a single file");
32+
33+
// Create a compiler instance to handle the actual work.
34+
CompilerInstance Compiler(std::move(PCHContainerOps));
35+
Compiler.setInvocation(std::move(Invocation));
36+
Compiler.setFileManager(Files);
37+
38+
// Create the compiler's actual diagnostics engine.
39+
Compiler.createDiagnostics(DiagConsumer, /*ShouldOwnClient=*/false);
40+
if (!Compiler.hasDiagnostics()) {
41+
return false;
42+
}
43+
44+
Compiler.createSourceManager(*Files);
45+
46+
// Ignore `Compiler.getFrontendOpts().ProgramAction` (would be `EmitBC`) and
47+
// create/execute an `EmitLLVMOnlyAction` (= codegen to LLVM module without
48+
// emitting anything) instead.
49+
EmitLLVMOnlyAction ELOA;
50+
const bool Success = Compiler.ExecuteAction(ELOA);
51+
Files->clearStatCache();
52+
if (!Success) {
53+
return false;
54+
}
55+
56+
// Take the module and its context to extend the objects' lifetime.
57+
Module = ELOA.takeModule();
58+
ELOA.takeLLVMContext();
59+
60+
return true;
61+
}
62+
63+
std::unique_ptr<llvm::Module> Module;
64+
};
65+
66+
} // anonymous namespace
67+
68+
std::unique_ptr<llvm::Module> jit_compiler::compileDeviceCode(
69+
const char *SYCLSource, View<IncludePair> IncludePairs,
70+
View<const char *> UserArgs, const char *DPCPPRoot) {
71+
72+
SmallVector<std::string> CommandLine = {"-fsycl-device-only"};
73+
// TODO: Allow instrumentation again when device library linking is
74+
// implemented.
75+
CommandLine.push_back("-fno-sycl-instrument-device-code");
76+
CommandLine.append(UserArgs.begin(), UserArgs.end());
77+
clang::tooling::FixedCompilationDatabase DB{"./", CommandLine};
78+
79+
constexpr auto SourcePath = "rtc.cpp";
80+
clang::tooling::ClangTool Tool{DB, {SourcePath}};
81+
82+
// Set up in-memory filesystem.
83+
Tool.mapVirtualFile(SourcePath, SYCLSource);
84+
for (const auto &IP : IncludePairs) {
85+
Tool.mapVirtualFile(IP.Path, IP.Contents);
86+
}
87+
88+
// Reset argument adjusters to drop the `-fsyntax-only` flag which is added by
89+
// default by this API.
90+
Tool.clearArgumentsAdjusters();
91+
// Then, modify argv[0] and set the resource directory so that the driver
92+
// picks up the correct SYCL environment.
93+
Tool.appendArgumentsAdjuster(
94+
[&DPCPPRoot](const CommandLineArguments &Args,
95+
StringRef Filename) -> CommandLineArguments {
96+
(void)Filename;
97+
CommandLineArguments NewArgs = Args;
98+
NewArgs[0] = (Twine(DPCPPRoot) + "/bin/clang++").str();
99+
NewArgs.push_back((Twine("-resource-dir=") + DPCPPRoot + "/lib/clang/" +
100+
Twine(CLANG_VERSION_MAJOR))
101+
.str());
102+
return NewArgs;
103+
});
104+
105+
GetLLVMModuleAction Action;
106+
if (!Tool.run(&Action)) {
107+
return std::move(Action.Module);
108+
}
109+
110+
return {};
111+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//==---- DeviceCompilation.h - Compile SYCL device code with libtooling ----==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SYCL_JIT_COMPILER_RTC_DEVICE_COMPILATION_H
10+
#define SYCL_JIT_COMPILER_RTC_DEVICE_COMPILATION_H
11+
12+
#include "Kernel.h"
13+
#include "View.h"
14+
15+
#include <memory>
16+
17+
namespace llvm {
18+
class Module;
19+
} // namespace llvm
20+
21+
namespace jit_compiler {
22+
23+
std::unique_ptr<llvm::Module> compileDeviceCode(const char *SYCLSource,
24+
View<IncludePair> IncludePairs,
25+
View<const char *> UserArgs,
26+
const char *DPCPPRoot);
27+
28+
} // namespace jit_compiler
29+
30+
#endif // SYCL_JIT_COMPILER_RTC_DEVICE_COMPILATION_H

sycl-jit/jit-compiler/lib/translation/SPIRVLLVMTranslation.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ SPIRV::TranslatorOpts &SPIRVLLVMTranslator::translatorOpts() {
4141
// there's currently no obvious way to iterate the
4242
// array of extensions in KernelInfo.
4343
TransOpt.enableAllExtensions();
44+
// TODO: Remove this workaround.
45+
TransOpt.setAllowedToUseExtension(
46+
SPIRV::ExtensionID::SPV_KHR_untyped_pointers, false);
4447
TransOpt.setDesiredBIsRepresentation(
4548
SPIRV::BIsRepresentation::SPIRVFriendlyIR);
4649
// TODO: We need to take care of specialization constants, either by

sycl/source/detail/jit_compiler.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
#include <sycl/detail/ur.hpp>
1919
#include <sycl/kernel_bundle.hpp>
2020

21+
#include <dlfcn.h>
22+
#include <link.h>
23+
2124
namespace sycl {
2225
inline namespace _V1 {
2326
namespace detail {
@@ -74,6 +77,31 @@ jit_compiler::jit_compiler() {
7477
return false;
7578
}
7679

80+
this->CompileSYCLHandle = reinterpret_cast<CompileSYCLFuncT>(
81+
sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr, "compileSYCL"));
82+
if (!this->CompileSYCLHandle) {
83+
printPerformanceWarning(
84+
"Cannot resolve JIT library function entry point");
85+
return false;
86+
}
87+
88+
// TODO: Move this query to a more appropriate location (e.g. add
89+
// `sycl::detail::ur::getOsLibraryPath`), and handle non-POSIX OSs. For now,
90+
// it should be fine here because the JIT is not built on Windows.
91+
link_map *Map = nullptr;
92+
if (dlinfo(LibraryPtr, RTLD_DI_LINKMAP, &Map) == 0) {
93+
std::string LoadedLibraryPath = Map->l_name;
94+
std::string JITLibraryPathSuffix = "/lib/" + JITLibraryName;
95+
auto Pos = LoadedLibraryPath.rfind(JITLibraryPathSuffix);
96+
if (Pos != std::string::npos) {
97+
this->DPCPPRoot = LoadedLibraryPath.substr(0, Pos);
98+
}
99+
}
100+
if (this->DPCPPRoot.empty()) {
101+
printPerformanceWarning("Cannot determine JIT library location");
102+
return false;
103+
}
104+
77105
return true;
78106
};
79107
Available = checkJITLibrary();
@@ -1143,6 +1171,45 @@ std::vector<uint8_t> jit_compiler::encodeReqdWorkGroupSize(
11431171
return Encoded;
11441172
}
11451173

1174+
std::vector<uint8_t> jit_compiler::compileSYCL(
1175+
const std::string &SYCLSource,
1176+
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
1177+
const std::vector<std::string> &UserArgs, std::string *LogPtr,
1178+
const std::vector<std::string> &RegisteredKernelNames) {
1179+
1180+
// TODO: Handle situation.
1181+
assert(RegisteredKernelNames.empty() &&
1182+
"Instantiation of kernel templates NYI");
1183+
1184+
std::vector<::jit_compiler::IncludePair> IncludePairsView;
1185+
IncludePairsView.reserve(IncludePairs.size());
1186+
std::transform(IncludePairs.begin(), IncludePairs.end(),
1187+
std::back_inserter(IncludePairsView), [](const auto &Pair) {
1188+
return ::jit_compiler::IncludePair{Pair.first.c_str(),
1189+
Pair.second.c_str()};
1190+
});
1191+
std::vector<const char *> UserArgsView;
1192+
UserArgsView.reserve(UserArgs.size());
1193+
std::transform(UserArgs.begin(), UserArgs.end(),
1194+
std::back_inserter(UserArgsView),
1195+
[](const auto &Arg) { return Arg.c_str(); });
1196+
1197+
auto Result = CompileSYCLHandle(SYCLSource.c_str(), IncludePairsView,
1198+
UserArgsView, DPCPPRoot.c_str());
1199+
1200+
if (Result.failed()) {
1201+
throw sycl::exception(sycl::errc::build, Result.getErrorMessage());
1202+
}
1203+
1204+
// TODO: We currently don't have a meaningful build log.
1205+
(void)LogPtr;
1206+
1207+
const auto &BI = Result.getKernelInfo().BinaryInfo;
1208+
assert(BI.Format == ::jit_compiler::BinaryFormat::SPIRV);
1209+
std::vector<uint8_t> SPV(BI.BinaryStart, BI.BinaryStart + BI.BinarySize);
1210+
return SPV;
1211+
}
1212+
11461213
} // namespace detail
11471214
} // namespace _V1
11481215
} // namespace sycl

sycl/source/detail/jit_compiler.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ class jit_compiler {
4444
const std::string &KernelName,
4545
const std::vector<unsigned char> &SpecConstBlob);
4646

47+
std::vector<uint8_t> compileSYCL(
48+
const std::string &SYCLSource,
49+
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
50+
const std::vector<std::string> &UserArgs, std::string *LogPtr,
51+
const std::vector<std::string> &RegisteredKernelNames);
52+
4753
bool isAvailable() { return Available; }
4854

4955
static jit_compiler &get_instance() {
@@ -80,12 +86,15 @@ class jit_compiler {
8086
using FuseKernelsFuncT = decltype(::jit_compiler::fuseKernels) *;
8187
using MaterializeSpecConstFuncT =
8288
decltype(::jit_compiler::materializeSpecConstants) *;
89+
using CompileSYCLFuncT = decltype(::jit_compiler::compileSYCL) *;
8390
using ResetConfigFuncT = decltype(::jit_compiler::resetJITConfiguration) *;
8491
using AddToConfigFuncT = decltype(::jit_compiler::addToJITConfiguration) *;
8592
FuseKernelsFuncT FuseKernelsHandle = nullptr;
8693
MaterializeSpecConstFuncT MaterializeSpecConstHandle = nullptr;
94+
CompileSYCLFuncT CompileSYCLHandle = nullptr;
8795
ResetConfigFuncT ResetConfigHandle = nullptr;
8896
AddToConfigFuncT AddToConfigHandle = nullptr;
97+
std::string DPCPPRoot;
8998
#endif // SYCL_EXT_JIT_ENABLE
9099
};
91100

0 commit comments

Comments
 (0)