Skip to content

Commit db562cb

Browse files
committed
Return sycl_device_binaries from JIT library
Signed-off-by: Julian Oppermann <[email protected]>
1 parent b1eeac3 commit db562cb

File tree

5 files changed

+61
-69
lines changed

5 files changed

+61
-69
lines changed

sycl/source/detail/jit_compiler.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,7 @@ sycl_device_binaries jit_compiler::createPIDeviceBinary(
11191119
return JITDeviceBinaries.back().getPIDeviceStruct();
11201120
}
11211121

1122-
const RTDeviceBinaryImage &jit_compiler::createDeviceBinaryImage(
1122+
sycl_device_binaries jit_compiler::createDeviceBinaryImage(
11231123
const ::jit_compiler::RTCBundleInfo &BundleInfo) {
11241124
DeviceBinaryContainer Binary;
11251125
for (const auto &Symbol : BundleInfo.SymbolTable) {
@@ -1153,13 +1153,7 @@ const RTDeviceBinaryImage &jit_compiler::createDeviceBinaryImage(
11531153
: __SYCL_DEVICE_BINARY_TARGET_SPIRV32,
11541154
SYCL_DEVICE_BINARY_TYPE_SPIRV);
11551155
JITDeviceBinaries.push_back(std::move(Collection));
1156-
// TODO: If we want to handle multiple device binary images, we should instead
1157-
// return `sycl_device_binaries`, to be passed to
1158-
// `program_manager::addImages`. The program manager then creates and
1159-
// owns the `RTDeviceBinaryImage` instances.
1160-
RTCDeviceBinaryImages.emplace_back(
1161-
&JITDeviceBinaries.back().getPIDeviceStruct()->DeviceBinaries[0]);
1162-
return RTCDeviceBinaryImages.back();
1156+
return JITDeviceBinaries.back().getPIDeviceStruct();
11631157
}
11641158

11651159
std::vector<uint8_t> jit_compiler::encodeArgUsageMask(
@@ -1210,7 +1204,7 @@ std::vector<uint8_t> jit_compiler::encodeReqdWorkGroupSize(
12101204
return Encoded;
12111205
}
12121206

1213-
const RTDeviceBinaryImage &jit_compiler::compileSYCL(
1207+
sycl_device_binaries jit_compiler::compileSYCL(
12141208
const std::string &Id, const std::string &SYCLSource,
12151209
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
12161210
const std::vector<std::string> &UserArgs, std::string *LogPtr,

sycl/source/detail/jit_compiler.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class jit_compiler {
4444
const std::string &KernelName,
4545
const std::vector<unsigned char> &SpecConstBlob);
4646

47-
const RTDeviceBinaryImage &compileSYCL(
47+
sycl_device_binaries compileSYCL(
4848
const std::string &Id, const std::string &SYCLSource,
4949
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
5050
const std::vector<std::string> &UserArgs, std::string *LogPtr,
@@ -69,7 +69,7 @@ class jit_compiler {
6969
createPIDeviceBinary(const ::jit_compiler::SYCLKernelInfo &FusedKernelInfo,
7070
::jit_compiler::BinaryFormat Format);
7171

72-
const RTDeviceBinaryImage &
72+
sycl_device_binaries
7373
createDeviceBinaryImage(const ::jit_compiler::RTCBundleInfo &BundleInfo);
7474

7575
std::vector<uint8_t>
@@ -84,9 +84,6 @@ class jit_compiler {
8484
// Manages the lifetime of the UR structs for device binaries.
8585
std::vector<DeviceBinariesCollection> JITDeviceBinaries;
8686

87-
// Manages the lifetime of the runtime wrappers for device binary images.
88-
std::vector<RTDeviceBinaryImage> RTCDeviceBinaryImages;
89-
9087
#if SYCL_EXT_JIT_ENABLE
9188
// Handles to the entry points of the lazily loaded JIT library.
9289
using FuseKernelsFuncT = decltype(::jit_compiler::fuseKernels) *;

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -468,59 +468,60 @@ class kernel_bundle_impl {
468468
}
469469

470470
if (!FetchedFromCache) {
471-
if (Language == syclex::source_language::sycl_jit) {
472-
const auto &SourceStr = std::get<std::string>(this->Source);
473-
const auto &Img = syclex::detail::SYCL_JIT_to_SPIRV(
474-
SourceStr, IncludePairs, BuildOptions, LogPtr,
475-
RegisteredKernelNames);
476-
UrProgram = ProgramManager::getInstance().createURProgram(Img, MContext,
477-
MDevices);
478-
} else {
479-
const auto spirv = [&]() -> std::vector<uint8_t> {
480-
if (Language == syclex::source_language::opencl) {
481-
// if successful, the log is empty. if failed, throws an error with
482-
// the compilation log.
483-
std::vector<uint32_t> IPVersionVec(Devices.size());
484-
std::transform(DeviceVec.begin(), DeviceVec.end(),
485-
IPVersionVec.begin(), [&](ur_device_handle_t d) {
486-
uint32_t ipVersion = 0;
487-
Adapter->call<UrApiKind::urDeviceGetInfo>(
488-
d, UR_DEVICE_INFO_IP_VERSION, sizeof(uint32_t),
489-
&ipVersion, nullptr);
490-
return ipVersion;
491-
});
492-
return syclex::detail::OpenCLC_to_SPIRV(*SourceStrPtr, IPVersionVec,
493-
BuildOptions, LogPtr);
494-
}
495-
if (Language == syclex::source_language::spirv) {
496-
const auto &SourceBytes =
497-
std::get<std::vector<std::byte>>(this->Source);
498-
std::vector<uint8_t> Result(SourceBytes.size());
499-
std::transform(SourceBytes.cbegin(), SourceBytes.cend(),
500-
Result.begin(),
501-
[](std::byte B) { return static_cast<uint8_t>(B); });
502-
return Result;
503-
}
504-
if (Language == syclex::source_language::sycl) {
505-
return syclex::detail::SYCL_to_SPIRV(*SourceStrPtr, IncludePairs,
506-
BuildOptions, LogPtr,
507-
RegisteredKernelNames);
508-
}
509-
throw sycl::exception(
510-
make_error_code(errc::invalid),
511-
"SYCL C++, OpenCL C and SPIR-V are the only supported "
512-
"languages at this time");
513-
}();
514-
515-
Adapter->call<UrApiKind::urProgramCreateWithIL>(
516-
ContextImpl->getHandleRef(), spirv.data(), spirv.size(), nullptr,
517-
&UrProgram);
518-
// program created by urProgramCreateWithIL is implicitly retained.
519-
if (UrProgram == nullptr)
520-
throw sycl::exception(
521-
sycl::make_error_code(errc::invalid),
522-
"urProgramCreateWithIL resulted in a null program handle.");
523-
}
471+
const auto spirv = [&]() -> std::vector<uint8_t> {
472+
if (Language == syclex::source_language::opencl) {
473+
// if successful, the log is empty. if failed, throws an error with
474+
// the compilation log.
475+
std::vector<uint32_t> IPVersionVec(Devices.size());
476+
std::transform(DeviceVec.begin(), DeviceVec.end(),
477+
IPVersionVec.begin(), [&](ur_device_handle_t d) {
478+
uint32_t ipVersion = 0;
479+
Adapter->call<UrApiKind::urDeviceGetInfo>(
480+
d, UR_DEVICE_INFO_IP_VERSION, sizeof(uint32_t),
481+
&ipVersion, nullptr);
482+
return ipVersion;
483+
});
484+
return syclex::detail::OpenCLC_to_SPIRV(*SourceStrPtr, IPVersionVec,
485+
BuildOptions, LogPtr);
486+
}
487+
if (Language == syclex::source_language::spirv) {
488+
const auto &SourceBytes =
489+
std::get<std::vector<std::byte>>(this->Source);
490+
std::vector<uint8_t> Result(SourceBytes.size());
491+
std::transform(SourceBytes.cbegin(), SourceBytes.cend(),
492+
Result.begin(),
493+
[](std::byte B) { return static_cast<uint8_t>(B); });
494+
return Result;
495+
}
496+
if (Language == syclex::source_language::sycl) {
497+
return syclex::detail::SYCL_to_SPIRV(*SourceStrPtr, IncludePairs,
498+
BuildOptions, LogPtr,
499+
RegisteredKernelNames);
500+
}
501+
if (Language == syclex::source_language::sycl_jit) {
502+
auto *Binaries = syclex::detail::SYCL_JIT_to_SPIRV(
503+
*SourceStrPtr, IncludePairs, BuildOptions, LogPtr,
504+
RegisteredKernelNames);
505+
assert(Binaries->NumDeviceBinaries == 1 &&
506+
"Device code splitting is not yet supported");
507+
return std::vector<uint8_t>(Binaries->DeviceBinaries->BinaryStart,
508+
Binaries->DeviceBinaries->BinaryEnd);
509+
}
510+
throw sycl::exception(
511+
make_error_code(errc::invalid),
512+
"SYCL C++, OpenCL C and SPIR-V are the only supported "
513+
"languages at this time");
514+
}();
515+
516+
Adapter->call<UrApiKind::urProgramCreateWithIL>(
517+
ContextImpl->getHandleRef(), spirv.data(), spirv.size(), nullptr,
518+
&UrProgram);
519+
// program created by urProgramCreateWithIL is implicitly retained.
520+
if (UrProgram == nullptr)
521+
throw sycl::exception(
522+
sycl::make_error_code(errc::invalid),
523+
"urProgramCreateWithIL resulted in a null program handle.");
524+
524525
} // if(!FetchedFromCache)
525526

526527
std::string XsFlags = extractXsFlags(BuildOptions);

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ bool SYCL_JIT_Compilation_Available() {
342342
#endif
343343
}
344344

345-
const sycl::detail::RTDeviceBinaryImage &SYCL_JIT_to_SPIRV(
345+
sycl_device_binaries SYCL_JIT_to_SPIRV(
346346
[[maybe_unused]] const std::string &SYCLSource,
347347
[[maybe_unused]] include_pairs_t IncludePairs,
348348
[[maybe_unused]] const std::vector<std::string> &UserArgs,

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ bool SYCL_Compilation_Available();
3535

3636
std::string userArgsAsString(const std::vector<std::string> &UserArguments);
3737

38-
const sycl::detail::RTDeviceBinaryImage &
38+
sycl_device_binaries
3939
SYCL_JIT_to_SPIRV(const std::string &Source, include_pairs_t IncludePairs,
4040
const std::vector<std::string> &UserArgs, std::string *LogPtr,
4141
const std::vector<std::string> &RegisteredKernelNames);

0 commit comments

Comments
 (0)