@@ -115,19 +115,19 @@ static ur_program_handle_t createSpirvProgram(const ContextImplPtr &Context,
115115}
116116
117117// TODO replace this with a new UR API function
118- static bool isDeviceBinaryTypeSupported (const context &C ,
118+ static bool isDeviceBinaryTypeSupported (const ContextImplPtr &ContextImpl ,
119119 ur::DeviceBinaryType Format) {
120120 // All formats except SYCL_DEVICE_BINARY_TYPE_SPIRV are supported.
121121 if (Format != SYCL_DEVICE_BINARY_TYPE_SPIRV)
122122 return true ;
123123
124- const backend ContextBackend = detail::getSyclObjImpl (C) ->getBackend ();
124+ const backend ContextBackend = ContextImpl ->getBackend ();
125125
126126 // The CUDA backend cannot use SPIR-V
127127 if (ContextBackend == backend::ext_oneapi_cuda)
128128 return false ;
129129
130- std::vector<device> Devices = C. get_devices ();
130+ const std::vector<device> & Devices = ContextImpl-> getDevices ();
131131
132132 // Program type is SPIR-V, so we need a device compiler to do JIT.
133133 for (const device &D : Devices) {
@@ -137,7 +137,8 @@ static bool isDeviceBinaryTypeSupported(const context &C,
137137
138138 // OpenCL 2.1 and greater require clCreateProgramWithIL
139139 if (ContextBackend == backend::opencl) {
140- std::string ver = C.get_platform ().get_info <info::platform::version>();
140+ std::string ver = ContextImpl->get_info <info::context::platform>()
141+ .get_info <info::platform::version>();
141142 if (ver.find (" OpenCL 1.0" ) == std::string::npos &&
142143 ver.find (" OpenCL 1.1" ) == std::string::npos &&
143144 ver.find (" OpenCL 1.2" ) == std::string::npos &&
@@ -187,16 +188,15 @@ static bool isDeviceBinaryTypeSupported(const context &C,
187188
188189ur_program_handle_t
189190ProgramManager::createURProgram (const RTDeviceBinaryImage &Img,
190- const context &Context ,
191+ const ContextImplPtr &ContextImpl ,
191192 const std::vector<device> &Devices) {
192193 if constexpr (DbgProgMgr > 0 ) {
193194 std::vector<ur_device_handle_t > URDevices;
194195 std::transform (
195196 Devices.begin (), Devices.end (), std::back_inserter (URDevices),
196197 [](const device &Dev) { return getSyclObjImpl (Dev)->getHandleRef (); });
197198 std::cerr << " >>> ProgramManager::createPIProgram(" << &Img << " , "
198- << getSyclObjImpl (Context).get () << " , " << VecToString (URDevices)
199- << " )\n " ;
199+ << ContextImpl.get () << " , " << VecToString (URDevices) << " )\n " ;
200200 }
201201 const sycl_device_binary_struct &RawImg = Img.getRawData ();
202202
@@ -224,7 +224,7 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
224224 // sycl::detail::pi::PiDeviceBinaryType Format = Img->Format;
225225 // assert(Format != SYCL_DEVICE_BINARY_TYPE_NONE && "Image format not set");
226226
227- if (!isDeviceBinaryTypeSupported (Context , Format))
227+ if (!isDeviceBinaryTypeSupported (ContextImpl , Format))
228228 throw sycl::exception (
229229 sycl::errc::feature_not_supported,
230230 " SPIR-V online compilation is not supported in this context" );
@@ -233,23 +233,22 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
233233 const auto &ProgMetadata = Img.getProgramMetadataUR ();
234234
235235 // Load the image
236- const ContextImplPtr &Ctx = getSyclObjImpl (Context);
237236 std::vector<const uint8_t *> Binaries (
238237 Devices.size (), const_cast <uint8_t *>(RawImg.BinaryStart ));
239238 std::vector<size_t > Lengths (Devices.size (), ImgSize);
240239 ur_program_handle_t Res =
241240 Format == SYCL_DEVICE_BINARY_TYPE_SPIRV
242- ? createSpirvProgram (Ctx , RawImg.BinaryStart , ImgSize)
243- : createBinaryProgram (Ctx , Devices, Binaries. data (), Lengths .data (),
244- ProgMetadata);
241+ ? createSpirvProgram (ContextImpl , RawImg.BinaryStart , ImgSize)
242+ : createBinaryProgram (ContextImpl , Devices, Binaries.data (),
243+ Lengths. data (), ProgMetadata);
245244
246245 {
247246 std::lock_guard<std::mutex> Lock (MNativeProgramsMutex);
248247 // associate the UR program with the image it was created for
249- NativePrograms.insert ({Res, {Ctx , &Img}});
248+ NativePrograms.insert ({Res, {ContextImpl , &Img}});
250249 }
251250
252- Ctx ->addDeviceGlobalInitializer (Res, Devices, &Img);
251+ ContextImpl ->addDeviceGlobalInitializer (Res, Devices, &Img);
253252
254253 if constexpr (DbgProgMgr > 1 )
255254 std::cerr << " created program: " << Res
@@ -518,7 +517,7 @@ static void applyOptionsFromEnvironment(std::string &CompileOpts,
518517std::pair<ur_program_handle_t , bool > ProgramManager::getOrCreateURProgram (
519518 const RTDeviceBinaryImage &MainImg,
520519 const std::vector<const RTDeviceBinaryImage *> &AllImages,
521- const context &Context , const std::vector<device> &Devices,
520+ const ContextImplPtr &ContextImpl , const std::vector<device> &Devices,
522521 const std::string &CompileAndLinkOptions, SerializedObj SpecConsts) {
523522 ur_program_handle_t NativePrg;
524523
@@ -540,11 +539,10 @@ std::pair<ur_program_handle_t, bool> ProgramManager::getOrCreateURProgram(
540539 ProgMetadataVector.insert (ProgMetadataVector.end (),
541540 ImgProgMetadata.begin (), ImgProgMetadata.end ());
542541 }
543- NativePrg =
544- createBinaryProgram (getSyclObjImpl (Context), Devices, BinPtrs.data (),
545- Lengths.data (), ProgMetadataVector);
542+ NativePrg = createBinaryProgram (ContextImpl, Devices, BinPtrs.data (),
543+ Lengths.data (), ProgMetadataVector);
546544 } else {
547- NativePrg = createURProgram (MainImg, Context , Devices);
545+ NativePrg = createURProgram (MainImg, ContextImpl , Devices);
548546 }
549547 return {NativePrg, Binaries.size ()};
550548}
@@ -857,10 +855,10 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
857855 sizeof (ur_bool_t ), &MustBuildOnSubdevice, nullptr );
858856 }
859857
860- auto Context = createSyclObjFromImpl<context>(ContextImpl);
861858 auto Device = createSyclObjFromImpl<device>(
862859 MustBuildOnSubdevice == true ? DeviceImpl : RootDevImpl);
863- const RTDeviceBinaryImage &Img = getDeviceImage (KernelName, Context, Device);
860+ const RTDeviceBinaryImage &Img =
861+ getDeviceImage (KernelName, ContextImpl, Device);
864862
865863 // Check that device supports all aspects used by the kernel
866864 if (auto exception = checkDevSupportDeviceRequirements (Device, Img, NDRDesc))
@@ -879,19 +877,19 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
879877 std::copy (DeviceImagesToLink.begin (), DeviceImagesToLink.end (),
880878 std::back_inserter (AllImages));
881879
882- return getBuiltURProgram (std::move (AllImages), Context, {std::move (Device)});
880+ return getBuiltURProgram (std::move (AllImages), ContextImpl,
881+ {std::move (Device)});
883882}
884883
885884ur_program_handle_t ProgramManager::getBuiltURProgram (
886- const BinImgWithDeps &ImgWithDeps, const context &Context ,
885+ const BinImgWithDeps &ImgWithDeps, const ContextImplPtr &ContextImpl ,
887886 const std::vector<device> &Devs, const DevImgPlainWithDeps *DevImgWithDeps,
888887 const SerializedObj &SpecConsts) {
889888 std::string CompileOpts;
890889 std::string LinkOpts;
891890 applyOptionsFromEnvironment (CompileOpts, LinkOpts);
892- auto BuildF = [this , &ImgWithDeps, &DevImgWithDeps, &Context , &Devs,
891+ auto BuildF = [this , &ImgWithDeps, &DevImgWithDeps, &ContextImpl , &Devs,
893892 &CompileOpts, &LinkOpts, &SpecConsts] {
894- const ContextImplPtr &ContextImpl = getSyclObjImpl (Context);
895893 const AdapterPtr &Adapter = ContextImpl->getAdapter ();
896894 const RTDeviceBinaryImage &MainImg = *ImgWithDeps.getMain ();
897895 applyOptionsFromImage (CompileOpts, LinkOpts, MainImg, Devs, Adapter);
@@ -900,7 +898,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
900898 appendLinkEnvironmentVariablesThatAppend (LinkOpts);
901899
902900 auto [NativePrg, DeviceCodeWasInCache] =
903- getOrCreateURProgram (MainImg, ImgWithDeps.getAll (), Context , Devs,
901+ getOrCreateURProgram (MainImg, ImgWithDeps.getAll (), ContextImpl , Devs,
904902 CompileOpts + LinkOpts, SpecConsts);
905903
906904 if (!DeviceCodeWasInCache && MainImg.supportsSpecConstants ()) {
@@ -940,7 +938,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
940938 if (UseDeviceLibs)
941939 DeviceLibReqMask |= getDeviceLibReqMask (*BinImg);
942940
943- ur_program_handle_t NativePrg = createURProgram (*BinImg, Context, Devs);
941+ ur_program_handle_t NativePrg =
942+ createURProgram (*BinImg, ContextImpl, Devs);
944943
945944 if (BinImg->supportsSpecConstants ()) {
946945 enableITTAnnotationsIfNeeded (NativePrg, Adapter);
@@ -1005,7 +1004,6 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
10051004 auto CacheKey =
10061005 std::make_pair (std::make_pair (SpecConsts, ImgId), URDevicesSet);
10071006
1008- const ContextImplPtr &ContextImpl = getSyclObjImpl (Context);
10091007 KernelProgramCache &Cache = ContextImpl->getKernelProgramCache ();
10101008 auto GetCachedBuildF = [&Cache, &CacheKey]() {
10111009 return Cache.getOrInsertProgram (CacheKey);
@@ -1480,7 +1478,8 @@ sycl_device_binary getRawImg(RTDeviceBinaryImage *Img) {
14801478template <typename StorageKey>
14811479RTDeviceBinaryImage *getBinImageFromMultiMap (
14821480 const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
1483- const StorageKey &Key, const context &Context, const device &Device) {
1481+ const StorageKey &Key, const ContextImplPtr &ContextImpl,
1482+ const device &Device) {
14841483 auto [ItBegin, ItEnd] = ImagesSet.equal_range (Key);
14851484 if (ItBegin == ItEnd)
14861485 return nullptr ;
@@ -1510,19 +1509,20 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
15101509 uint32_t ImgInd = 0 ;
15111510 // Ask the native runtime under the given context to choose the device image
15121511 // it prefers.
1513- getSyclObjImpl (Context) ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1512+ ContextImpl ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
15141513 getSyclObjImpl (Device)->getHandleRef (), UrBinaries.data (),
15151514 UrBinaries.size (), &ImgInd);
15161515 return DeviceFilteredImgs[ImgInd];
15171516}
15181517
15191518RTDeviceBinaryImage &
15201519ProgramManager::getDeviceImage (const std::string &KernelName,
1521- const context &Context, const device &Device) {
1520+ const ContextImplPtr &ContextImpl,
1521+ const device &Device) {
15221522 if constexpr (DbgProgMgr > 0 ) {
15231523 std::cerr << " >>> ProgramManager::getDeviceImage(\" " << KernelName << " \" , "
1524- << getSyclObjImpl (Context) .get () << " , "
1525- << getSyclObjImpl (Device). get () << " )\n " ;
1524+ << ContextImpl .get () << " , " << getSyclObjImpl (Device). get ()
1525+ << " )\n " ;
15261526
15271527 std::cerr << " available device images:\n " ;
15281528 debugPrintBinaryImages ();
@@ -1532,7 +1532,7 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15321532 assert (m_SpvFileImage);
15331533 return getDeviceImage (
15341534 std::unordered_set<RTDeviceBinaryImage *>({m_SpvFileImage.get ()}),
1535- Context , Device);
1535+ ContextImpl , Device);
15361536 }
15371537
15381538 RTDeviceBinaryImage *Img = nullptr ;
@@ -1541,9 +1541,9 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15411541 if (auto KernelId = m_KernelName2KernelIDs.find (KernelName);
15421542 KernelId != m_KernelName2KernelIDs.end ()) {
15431543 Img = getBinImageFromMultiMap (m_KernelIDs2BinImage, KernelId->second ,
1544- Context , Device);
1544+ ContextImpl , Device);
15451545 } else {
1546- Img = getBinImageFromMultiMap (m_ServiceKernels, KernelName, Context ,
1546+ Img = getBinImageFromMultiMap (m_ServiceKernels, KernelName, ContextImpl ,
15471547 Device);
15481548 }
15491549 }
@@ -1565,13 +1565,13 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15651565
15661566RTDeviceBinaryImage &ProgramManager::getDeviceImage (
15671567 const std::unordered_set<RTDeviceBinaryImage *> &ImageSet,
1568- const context &Context , const device &Device) {
1568+ const ContextImplPtr &ContextImpl , const device &Device) {
15691569 assert (ImageSet.size () > 0 );
15701570
15711571 if constexpr (DbgProgMgr > 0 ) {
15721572 std::cerr << " >>> ProgramManager::getDeviceImage(Custom SPV file "
1573- << getSyclObjImpl (Context) .get () << " , "
1574- << getSyclObjImpl (Device). get () << " )\n " ;
1573+ << ContextImpl .get () << " , " << getSyclObjImpl (Device). get ()
1574+ << " )\n " ;
15751575
15761576 std::cerr << " available device images:\n " ;
15771577 debugPrintBinaryImages ();
@@ -1593,7 +1593,7 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(
15931593 getUrDeviceTarget (RawImgs[BinaryCount]->DeviceTargetSpec );
15941594 }
15951595
1596- getSyclObjImpl (Context) ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1596+ ContextImpl ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
15971597 getSyclObjImpl (Device)->getHandleRef (), UrBinaries.data (),
15981598 UrBinaries.size (), &ImgInd);
15991599
@@ -2888,8 +2888,9 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
28882888 const AdapterPtr &Adapter =
28892889 getSyclObjImpl (InputImpl->get_context ())->getAdapter ();
28902890
2891- ur_program_handle_t Prog = createURProgram (*InputImpl->get_bin_image_ref (),
2892- InputImpl->get_context (), Devs);
2891+ ur_program_handle_t Prog =
2892+ createURProgram (*InputImpl->get_bin_image_ref (),
2893+ getSyclObjImpl (InputImpl->get_context ()), Devs);
28932894
28942895 if (InputImpl->get_bin_image_ref ()->supportsSpecConstants ())
28952896 setSpecializationConstants (InputImpl, Prog, Adapter);
@@ -3097,7 +3098,8 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
30973098 const std::shared_ptr<device_image_impl> &MainInputImpl =
30983099 getSyclObjImpl (DevImgWithDeps.getMain ());
30993100
3100- const context Context = MainInputImpl->get_context ();
3101+ const context &Context = MainInputImpl->get_context ();
3102+ const ContextImplPtr &ContextImpl = detail::getSyclObjImpl (Context);
31013103
31023104 std::vector<const RTDeviceBinaryImage *> BinImgs;
31033105 BinImgs.reserve (DevImgWithDeps.size ());
@@ -3138,7 +3140,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31383140 auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge (RTCInfoPtrs);
31393141
31403142 ur_program_handle_t ResProgram = getBuiltURProgram (
3141- std::move (BinImgs), Context , Devs, &DevImgWithDeps, SpecConstBlob);
3143+ std::move (BinImgs), ContextImpl , Devs, &DevImgWithDeps, SpecConstBlob);
31423144
31433145 DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
31443146 MainInputImpl->get_bin_image_ref (), Context, Devs,
@@ -3259,7 +3261,8 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
32593261
32603262 if constexpr (DbgProgMgr > 0 )
32613263 std::cerr << " >>> Adding the kernel to the cache.\n " ;
3262- auto Program = createURProgram (Img, Context, {Device});
3264+ const ContextImplPtr &ContextImpl = detail::getSyclObjImpl (Context);
3265+ auto Program = createURProgram (Img, ContextImpl, {Device});
32633266 auto DeviceImpl = detail::getSyclObjImpl (Device);
32643267 auto &Adapter = DeviceImpl->getAdapter ();
32653268 UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
@@ -3274,8 +3277,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
32743277 std::vector<ur_program_handle_t > ExtraProgramsToLink;
32753278 std::vector<ur_device_handle_t > Devs = {DeviceImpl->getHandleRef ()};
32763279 auto BuildProgram =
3277- build (std::move (ProgramManaged), detail::getSyclObjImpl (Context),
3278- CompileOpts, LinkOpts, Devs,
3280+ build (std::move (ProgramManaged), ContextImpl, CompileOpts, LinkOpts, Devs,
32793281 /* For non SPIR-V devices DeviceLibReqdMask is always 0*/ 0 ,
32803282 ExtraProgramsToLink);
32813283 ur_kernel_handle_t UrKernel{nullptr };
0 commit comments