1212#include < detail/kernel_compiler/kernel_compiler_opencl.hpp>
1313#include < detail/kernel_compiler/kernel_compiler_sycl.hpp>
1414#include < detail/kernel_impl.hpp>
15+ #include < detail/persistent_device_code_cache.hpp>
1516#include < detail/program_manager/program_manager.hpp>
1617#include < sycl/backend_types.hpp>
1718#include < sycl/context.hpp>
@@ -396,6 +397,53 @@ class kernel_bundle_impl {
396397 return SS.str ();
397398 }
398399
400+ bool
401+ extKernelCompilerFetchFromCache (const std::vector<device> Devices,
402+ const std::vector<std::string> &BuildOptions,
403+ const std::string &SourceStr,
404+ ur_program_handle_t &UrProgram) {
405+ using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
406+ ContextImplPtr ContextImpl = getSyclObjImpl (MContext);
407+ const AdapterPtr &Adapter = ContextImpl->getAdapter ();
408+
409+ std::string UserArgs = syclex::detail::userArgsAsString (BuildOptions);
410+
411+ std::vector<ur_device_handle_t > DeviceHandles;
412+ std::transform (
413+ Devices.begin (), Devices.end (), std::back_inserter (DeviceHandles),
414+ [](const device &Dev) { return getSyclObjImpl (Dev)->getHandleRef (); });
415+
416+ std::vector<const uint8_t *> Binaries;
417+ std::vector<size_t > Lengths;
418+ std::vector<std::vector<std::vector<char >>> PersistentBinaries;
419+ for (size_t i = 0 ; i < Devices.size (); i++) {
420+ std::vector<std::vector<char >> BinProg =
421+ PersistentDeviceCodeCache::getCompiledKernelFromDisc (
422+ Devices[i], UserArgs, SourceStr);
423+
424+ // exit if any device binary is missing
425+ if (BinProg.empty ()) {
426+ return false ;
427+ }
428+ PersistentBinaries.push_back (BinProg);
429+
430+ Binaries.push_back ((uint8_t *)(PersistentBinaries[i][0 ].data ()));
431+ Lengths.push_back (PersistentBinaries[i][0 ].size ());
432+ }
433+
434+ ur_program_properties_t Properties = {};
435+ Properties.stype = UR_STRUCTURE_TYPE_PROGRAM_PROPERTIES;
436+ Properties.pNext = nullptr ;
437+ Properties.count = 0 ;
438+ Properties.pMetadatas = nullptr ;
439+
440+ Adapter->call <UrApiKind::urProgramCreateWithBinary>(
441+ ContextImpl->getHandleRef (), DeviceHandles.size (), DeviceHandles.data (),
442+ Lengths.data (), Binaries.data (), &Properties, &UrProgram);
443+
444+ return true ;
445+ }
446+
399447 std::shared_ptr<kernel_bundle_impl>
400448 build_from_source (const std::vector<device> Devices,
401449 const std::vector<std::string> &BuildOptions,
@@ -415,57 +463,68 @@ class kernel_bundle_impl {
415463 DeviceVec.push_back (Dev);
416464 }
417465
418- const auto spirv = [&]() -> std::vector<uint8_t > {
419- if (Language == syclex::source_language::opencl) {
420- // if successful, the log is empty. if failed, throws an error with the
421- // compilation log.
422- const auto &SourceStr = std::get<std::string>(this ->Source );
423- std::vector<uint32_t > IPVersionVec (Devices.size ());
424- std::transform (DeviceVec.begin (), DeviceVec.end (), IPVersionVec.begin (),
425- [&](ur_device_handle_t d) {
426- uint32_t ipVersion = 0 ;
427- Adapter->call <UrApiKind::urDeviceGetInfo>(
428- d, UR_DEVICE_INFO_IP_VERSION, sizeof (uint32_t ),
429- &ipVersion, nullptr );
430- return ipVersion;
431- });
432- return syclex::detail::OpenCLC_to_SPIRV (SourceStr, IPVersionVec,
433- BuildOptions, LogPtr);
434- }
435- if (Language == syclex::source_language::spirv) {
436- const auto &SourceBytes =
437- std::get<std::vector<std::byte>>(this ->Source );
438- std::vector<uint8_t > Result (SourceBytes.size ());
439- std::transform (SourceBytes.cbegin (), SourceBytes.cend (), Result.begin (),
440- [](std::byte B) { return static_cast <uint8_t >(B); });
441- return Result;
442- }
443- if (Language == syclex::source_language::sycl) {
444- const auto &SourceStr = std::get<std::string>(this ->Source );
445- return syclex::detail::SYCL_to_SPIRV (SourceStr, IncludePairs,
446- BuildOptions, LogPtr,
447- RegisteredKernelNames);
448- }
449- if (Language == syclex::source_language::sycl_jit) {
450- const auto &SourceStr = std::get<std::string>(this ->Source );
451- return syclex::detail::SYCL_JIT_to_SPIRV (SourceStr, IncludePairs,
452- BuildOptions, LogPtr,
453- RegisteredKernelNames);
454- }
455- throw sycl::exception (
456- make_error_code (errc::invalid),
457- " OpenCL C and SPIR-V are the only supported languages at this time" );
458- }();
459-
460466 ur_program_handle_t UrProgram = nullptr ;
461- Adapter->call <UrApiKind::urProgramCreateWithIL>(ContextImpl->getHandleRef (),
462- spirv.data (), spirv.size (),
463- nullptr , &UrProgram);
464- // program created by urProgramCreateWithIL is implicitly retained.
465- if (UrProgram == nullptr )
466- throw sycl::exception (
467- sycl::make_error_code (errc::invalid),
468- " urProgramCreateWithIL resulted in a null program handle." );
467+ // SourceStrPtr will be null when source is Spir-V bytes.
468+ const std::string *SourceStrPtr = std::get_if<std::string>(&this ->Source );
469+ bool FetchedFromCache = false ;
470+ if (PersistentDeviceCodeCache::isEnabled () && SourceStrPtr) {
471+ FetchedFromCache = extKernelCompilerFetchFromCache (
472+ Devices, BuildOptions, *SourceStrPtr, UrProgram);
473+ }
474+
475+ if (!FetchedFromCache) {
476+ const auto spirv = [&]() -> std::vector<uint8_t > {
477+ if (Language == syclex::source_language::opencl) {
478+ // if successful, the log is empty. if failed, throws an error with
479+ // the compilation log.
480+ std::vector<uint32_t > IPVersionVec (Devices.size ());
481+ std::transform (DeviceVec.begin (), DeviceVec.end (),
482+ IPVersionVec.begin (), [&](ur_device_handle_t d) {
483+ uint32_t ipVersion = 0 ;
484+ Adapter->call <UrApiKind::urDeviceGetInfo>(
485+ d, UR_DEVICE_INFO_IP_VERSION, sizeof (uint32_t ),
486+ &ipVersion, nullptr );
487+ return ipVersion;
488+ });
489+ return syclex::detail::OpenCLC_to_SPIRV (*SourceStrPtr, IPVersionVec,
490+ BuildOptions, LogPtr);
491+ }
492+ if (Language == syclex::source_language::spirv) {
493+ const auto &SourceBytes =
494+ std::get<std::vector<std::byte>>(this ->Source );
495+ std::vector<uint8_t > Result (SourceBytes.size ());
496+ std::transform (SourceBytes.cbegin (), SourceBytes.cend (),
497+ Result.begin (),
498+ [](std::byte B) { return static_cast <uint8_t >(B); });
499+ return Result;
500+ }
501+ if (Language == syclex::source_language::sycl) {
502+ return syclex::detail::SYCL_to_SPIRV (*SourceStrPtr, IncludePairs,
503+ BuildOptions, LogPtr,
504+ RegisteredKernelNames);
505+ }
506+ if (Language == syclex::source_language::sycl_jit) {
507+ const auto &SourceStr = std::get<std::string>(this ->Source );
508+ return syclex::detail::SYCL_JIT_to_SPIRV (SourceStr, IncludePairs,
509+ BuildOptions, LogPtr,
510+ RegisteredKernelNames);
511+ }
512+ throw sycl::exception (
513+ make_error_code (errc::invalid),
514+ " SYCL C++, OpenCL C and SPIR-V are the only supported "
515+ " languages at this time" );
516+ }();
517+
518+ Adapter->call <UrApiKind::urProgramCreateWithIL>(
519+ ContextImpl->getHandleRef (), spirv.data (), spirv.size (), nullptr ,
520+ &UrProgram);
521+ // program created by urProgramCreateWithIL is implicitly retained.
522+ if (UrProgram == nullptr )
523+ throw sycl::exception (
524+ sycl::make_error_code (errc::invalid),
525+ " urProgramCreateWithIL resulted in a null program handle." );
526+
527+ } // if(!FetchedFromCache)
469528
470529 std::string XsFlags = extractXsFlags (BuildOptions);
471530 auto Res = Adapter->call_nocheck <UrApiKind::urProgramBuildExp>(
@@ -501,6 +560,17 @@ class kernel_bundle_impl {
501560 nullptr , MContext, MDevices, bundle_state::executable, KernelIDs,
502561 UrProgram);
503562 device_image_plain DevImg{DevImgImpl};
563+
564+ // If caching enabled and kernel not fetched from cache, cache.
565+ if (PersistentDeviceCodeCache::isEnabled () && !FetchedFromCache &&
566+ SourceStrPtr) {
567+ for (const auto &Device : Devices) {
568+ PersistentDeviceCodeCache::putCompiledKernelToDisc (
569+ Device, syclex::detail::userArgsAsString (BuildOptions),
570+ *SourceStrPtr, UrProgram);
571+ }
572+ }
573+
504574 return std::make_shared<kernel_bundle_impl>(MContext, MDevices, DevImg,
505575 KernelNames, Language);
506576 }
0 commit comments