@@ -115,19 +115,19 @@ static ur_program_handle_t createSpirvProgram(const ContextImplPtr &Context,
115115}
116116
117117// TODO replace this with a new UR API function
118- static bool isDeviceBinaryTypeSupported (const context &C ,
118+ static bool isDeviceBinaryTypeSupported (const ContextImplPtr &ContextImpl ,
119119 ur::DeviceBinaryType Format) {
120120 // All formats except SYCL_DEVICE_BINARY_TYPE_SPIRV are supported.
121121 if (Format != SYCL_DEVICE_BINARY_TYPE_SPIRV)
122122 return true ;
123123
124- const backend ContextBackend = detail::getSyclObjImpl (C) ->getBackend ();
124+ const backend ContextBackend = ContextImpl ->getBackend ();
125125
126126 // The CUDA backend cannot use SPIR-V
127127 if (ContextBackend == backend::ext_oneapi_cuda)
128128 return false ;
129129
130- std::vector<device> Devices = C. get_devices ();
130+ const std::vector<device> & Devices = ContextImpl-> getDevices ();
131131
132132 // Program type is SPIR-V, so we need a device compiler to do JIT.
133133 for (const device &D : Devices) {
@@ -137,7 +137,8 @@ static bool isDeviceBinaryTypeSupported(const context &C,
137137
138138 // OpenCL 2.1 and greater require clCreateProgramWithIL
139139 if (ContextBackend == backend::opencl) {
140- std::string ver = C.get_platform ().get_info <info::platform::version>();
140+ std::string ver = ContextImpl->get_info <info::context::platform>()
141+ .get_info <info::platform::version>();
141142 if (ver.find (" OpenCL 1.0" ) == std::string::npos &&
142143 ver.find (" OpenCL 1.1" ) == std::string::npos &&
143144 ver.find (" OpenCL 1.2" ) == std::string::npos &&
@@ -187,16 +188,15 @@ static bool isDeviceBinaryTypeSupported(const context &C,
187188
188189ur_program_handle_t
189190ProgramManager::createURProgram (const RTDeviceBinaryImage &Img,
190- const context &Context ,
191+ const ContextImplPtr &ContextImpl ,
191192 const std::vector<device> &Devices) {
192193 if constexpr (DbgProgMgr > 0 ) {
193194 std::vector<ur_device_handle_t > URDevices;
194195 std::transform (
195196 Devices.begin (), Devices.end (), std::back_inserter (URDevices),
196197 [](const device &Dev) { return getSyclObjImpl (Dev)->getHandleRef (); });
197198 std::cerr << " >>> ProgramManager::createPIProgram(" << &Img << " , "
198- << getSyclObjImpl (Context).get () << " , " << VecToString (URDevices)
199- << " )\n " ;
199+ << ContextImpl.get () << " , " << VecToString (URDevices) << " )\n " ;
200200 }
201201 const sycl_device_binary_struct &RawImg = Img.getRawData ();
202202
@@ -224,7 +224,7 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
224224 // sycl::detail::pi::PiDeviceBinaryType Format = Img->Format;
225225 // assert(Format != SYCL_DEVICE_BINARY_TYPE_NONE && "Image format not set");
226226
227- if (!isDeviceBinaryTypeSupported (Context , Format))
227+ if (!isDeviceBinaryTypeSupported (ContextImpl , Format))
228228 throw sycl::exception (
229229 sycl::errc::feature_not_supported,
230230 " SPIR-V online compilation is not supported in this context" );
@@ -233,23 +233,22 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
233233 const auto &ProgMetadata = Img.getProgramMetadataUR ();
234234
235235 // Load the image
236- const ContextImplPtr &Ctx = getSyclObjImpl (Context);
237236 std::vector<const uint8_t *> Binaries (
238237 Devices.size (), const_cast <uint8_t *>(RawImg.BinaryStart ));
239238 std::vector<size_t > Lengths (Devices.size (), ImgSize);
240239 ur_program_handle_t Res =
241240 Format == SYCL_DEVICE_BINARY_TYPE_SPIRV
242- ? createSpirvProgram (Ctx , RawImg.BinaryStart , ImgSize)
243- : createBinaryProgram (Ctx , Devices, Binaries. data (), Lengths .data (),
244- ProgMetadata);
241+ ? createSpirvProgram (ContextImpl , RawImg.BinaryStart , ImgSize)
242+ : createBinaryProgram (ContextImpl , Devices, Binaries.data (),
243+ Lengths. data (), ProgMetadata);
245244
246245 {
247246 std::lock_guard<std::mutex> Lock (MNativeProgramsMutex);
248247 // associate the UR program with the image it was created for
249- NativePrograms.insert ({Res, {Ctx , &Img}});
248+ NativePrograms.insert ({Res, {ContextImpl , &Img}});
250249 }
251250
252- Ctx ->addDeviceGlobalInitializer (Res, Devices, &Img);
251+ ContextImpl ->addDeviceGlobalInitializer (Res, Devices, &Img);
253252
254253 if constexpr (DbgProgMgr > 1 )
255254 std::cerr << " created program: " << Res
@@ -518,7 +517,7 @@ static void applyOptionsFromEnvironment(std::string &CompileOpts,
518517std::pair<ur_program_handle_t , bool > ProgramManager::getOrCreateURProgram (
519518 const RTDeviceBinaryImage &MainImg,
520519 const std::vector<const RTDeviceBinaryImage *> &AllImages,
521- const context &Context , const std::vector<device> &Devices,
520+ const ContextImplPtr &ContextImpl , const std::vector<device> &Devices,
522521 const std::string &CompileAndLinkOptions, SerializedObj SpecConsts) {
523522 ur_program_handle_t NativePrg;
524523
@@ -540,11 +539,10 @@ std::pair<ur_program_handle_t, bool> ProgramManager::getOrCreateURProgram(
540539 ProgMetadataVector.insert (ProgMetadataVector.end (),
541540 ImgProgMetadata.begin (), ImgProgMetadata.end ());
542541 }
543- NativePrg =
544- createBinaryProgram (getSyclObjImpl (Context), Devices, BinPtrs.data (),
545- Lengths.data (), ProgMetadataVector);
542+ NativePrg = createBinaryProgram (ContextImpl, Devices, BinPtrs.data (),
543+ Lengths.data (), ProgMetadataVector);
546544 } else {
547- NativePrg = createURProgram (MainImg, Context , Devices);
545+ NativePrg = createURProgram (MainImg, ContextImpl , Devices);
548546 }
549547 return {NativePrg, Binaries.size ()};
550548}
@@ -853,10 +851,10 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
853851 RootDevImpl->getHandleRef (), UR_DEVICE_INFO_BUILD_ON_SUBDEVICE,
854852 sizeof (ur_bool_t ), &MustBuildOnSubdevice, nullptr );
855853
856- auto Context = createSyclObjFromImpl<context>(ContextImpl);
857854 auto Device = createSyclObjFromImpl<device>(
858855 MustBuildOnSubdevice == true ? DeviceImpl : RootDevImpl);
859- const RTDeviceBinaryImage &Img = getDeviceImage (KernelName, Context, Device);
856+ const RTDeviceBinaryImage &Img =
857+ getDeviceImage (KernelName, ContextImpl, Device);
860858
861859 // Check that device supports all aspects used by the kernel
862860 if (auto exception = checkDevSupportDeviceRequirements (Device, Img, NDRDesc))
@@ -875,19 +873,19 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
875873 std::copy (DeviceImagesToLink.begin (), DeviceImagesToLink.end (),
876874 std::back_inserter (AllImages));
877875
878- return getBuiltURProgram (std::move (AllImages), Context, {std::move (Device)});
876+ return getBuiltURProgram (std::move (AllImages), ContextImpl,
877+ {std::move (Device)});
879878}
880879
881880ur_program_handle_t ProgramManager::getBuiltURProgram (
882- const BinImgWithDeps &ImgWithDeps, const context &Context ,
881+ const BinImgWithDeps &ImgWithDeps, const ContextImplPtr &ContextImpl ,
883882 const std::vector<device> &Devs, const DevImgPlainWithDeps *DevImgWithDeps,
884883 const SerializedObj &SpecConsts) {
885884 std::string CompileOpts;
886885 std::string LinkOpts;
887886 applyOptionsFromEnvironment (CompileOpts, LinkOpts);
888- auto BuildF = [this , &ImgWithDeps, &DevImgWithDeps, &Context , &Devs,
887+ auto BuildF = [this , &ImgWithDeps, &DevImgWithDeps, &ContextImpl , &Devs,
889888 &CompileOpts, &LinkOpts, &SpecConsts] {
890- const ContextImplPtr &ContextImpl = getSyclObjImpl (Context);
891889 const AdapterPtr &Adapter = ContextImpl->getAdapter ();
892890 const RTDeviceBinaryImage &MainImg = *ImgWithDeps.getMain ();
893891 applyOptionsFromImage (CompileOpts, LinkOpts, MainImg, Devs, Adapter);
@@ -896,7 +894,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
896894 appendLinkEnvironmentVariablesThatAppend (LinkOpts);
897895
898896 auto [NativePrg, DeviceCodeWasInCache] =
899- getOrCreateURProgram (MainImg, ImgWithDeps.getAll (), Context , Devs,
897+ getOrCreateURProgram (MainImg, ImgWithDeps.getAll (), ContextImpl , Devs,
900898 CompileOpts + LinkOpts, SpecConsts);
901899
902900 if (!DeviceCodeWasInCache && MainImg.supportsSpecConstants ()) {
@@ -936,7 +934,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
936934 if (UseDeviceLibs)
937935 DeviceLibReqMask |= getDeviceLibReqMask (*BinImg);
938936
939- ur_program_handle_t NativePrg = createURProgram (*BinImg, Context, Devs);
937+ ur_program_handle_t NativePrg =
938+ createURProgram (*BinImg, ContextImpl, Devs);
940939
941940 if (BinImg->supportsSpecConstants ()) {
942941 enableITTAnnotationsIfNeeded (NativePrg, Adapter);
@@ -1001,7 +1000,6 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
10011000 auto CacheKey =
10021001 std::make_pair (std::make_pair (SpecConsts, ImgId), URDevicesSet);
10031002
1004- const ContextImplPtr &ContextImpl = getSyclObjImpl (Context);
10051003 KernelProgramCache &Cache = ContextImpl->getKernelProgramCache ();
10061004 auto GetCachedBuildF = [&Cache, &CacheKey]() {
10071005 return Cache.getOrInsertProgram (CacheKey);
@@ -1476,7 +1474,8 @@ sycl_device_binary getRawImg(RTDeviceBinaryImage *Img) {
14761474template <typename StorageKey>
14771475RTDeviceBinaryImage *getBinImageFromMultiMap (
14781476 const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
1479- const StorageKey &Key, const context &Context, const device &Device) {
1477+ const StorageKey &Key, const ContextImplPtr &ContextImpl,
1478+ const device &Device) {
14801479 auto [ItBegin, ItEnd] = ImagesSet.equal_range (Key);
14811480 if (ItBegin == ItEnd)
14821481 return nullptr ;
@@ -1506,19 +1505,20 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
15061505 uint32_t ImgInd = 0 ;
15071506 // Ask the native runtime under the given context to choose the device image
15081507 // it prefers.
1509- getSyclObjImpl (Context) ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1508+ ContextImpl ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
15101509 getSyclObjImpl (Device)->getHandleRef (), UrBinaries.data (),
15111510 UrBinaries.size (), &ImgInd);
15121511 return DeviceFilteredImgs[ImgInd];
15131512}
15141513
15151514RTDeviceBinaryImage &
15161515ProgramManager::getDeviceImage (const std::string &KernelName,
1517- const context &Context, const device &Device) {
1516+ const ContextImplPtr &ContextImpl,
1517+ const device &Device) {
15181518 if constexpr (DbgProgMgr > 0 ) {
15191519 std::cerr << " >>> ProgramManager::getDeviceImage(\" " << KernelName << " \" , "
1520- << getSyclObjImpl (Context) .get () << " , "
1521- << getSyclObjImpl (Device). get () << " )\n " ;
1520+ << ContextImpl .get () << " , " << getSyclObjImpl (Device). get ()
1521+ << " )\n " ;
15221522
15231523 std::cerr << " available device images:\n " ;
15241524 debugPrintBinaryImages ();
@@ -1528,7 +1528,7 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15281528 assert (m_SpvFileImage);
15291529 return getDeviceImage (
15301530 std::unordered_set<RTDeviceBinaryImage *>({m_SpvFileImage.get ()}),
1531- Context , Device);
1531+ ContextImpl , Device);
15321532 }
15331533
15341534 RTDeviceBinaryImage *Img = nullptr ;
@@ -1537,9 +1537,9 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15371537 if (auto KernelId = m_KernelName2KernelIDs.find (KernelName);
15381538 KernelId != m_KernelName2KernelIDs.end ()) {
15391539 Img = getBinImageFromMultiMap (m_KernelIDs2BinImage, KernelId->second ,
1540- Context , Device);
1540+ ContextImpl , Device);
15411541 } else {
1542- Img = getBinImageFromMultiMap (m_ServiceKernels, KernelName, Context ,
1542+ Img = getBinImageFromMultiMap (m_ServiceKernels, KernelName, ContextImpl ,
15431543 Device);
15441544 }
15451545 }
@@ -1561,13 +1561,13 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15611561
15621562RTDeviceBinaryImage &ProgramManager::getDeviceImage (
15631563 const std::unordered_set<RTDeviceBinaryImage *> &ImageSet,
1564- const context &Context , const device &Device) {
1564+ const ContextImplPtr &ContextImpl , const device &Device) {
15651565 assert (ImageSet.size () > 0 );
15661566
15671567 if constexpr (DbgProgMgr > 0 ) {
15681568 std::cerr << " >>> ProgramManager::getDeviceImage(Custom SPV file "
1569- << getSyclObjImpl (Context) .get () << " , "
1570- << getSyclObjImpl (Device). get () << " )\n " ;
1569+ << ContextImpl .get () << " , " << getSyclObjImpl (Device). get ()
1570+ << " )\n " ;
15711571
15721572 std::cerr << " available device images:\n " ;
15731573 debugPrintBinaryImages ();
@@ -1589,7 +1589,7 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(
15891589 getUrDeviceTarget (RawImgs[BinaryCount]->DeviceTargetSpec );
15901590 }
15911591
1592- getSyclObjImpl (Context) ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1592+ ContextImpl ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
15931593 getSyclObjImpl (Device)->getHandleRef (), UrBinaries.data (),
15941594 UrBinaries.size (), &ImgInd);
15951595
@@ -2845,8 +2845,9 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
28452845 const AdapterPtr &Adapter =
28462846 getSyclObjImpl (InputImpl->get_context ())->getAdapter ();
28472847
2848- ur_program_handle_t Prog = createURProgram (*InputImpl->get_bin_image_ref (),
2849- InputImpl->get_context (), Devs);
2848+ ur_program_handle_t Prog =
2849+ createURProgram (*InputImpl->get_bin_image_ref (),
2850+ getSyclObjImpl (InputImpl->get_context ()), Devs);
28502851
28512852 if (InputImpl->get_bin_image_ref ()->supportsSpecConstants ())
28522853 setSpecializationConstants (InputImpl, Prog, Adapter);
@@ -3037,6 +3038,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
30373038 getSyclObjImpl (DevImgWithDeps.getMain ());
30383039
30393040 const context Context = MainInputImpl->get_context ();
3041+ const ContextImplPtr &ContextImpl = detail::getSyclObjImpl (Context);
30403042
30413043 std::vector<const RTDeviceBinaryImage *> BinImgs;
30423044 BinImgs.reserve (DevImgWithDeps.size ());
@@ -3065,7 +3067,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
30653067 }
30663068
30673069 ur_program_handle_t ResProgram = getBuiltURProgram (
3068- std::move (BinImgs), Context , Devs, &DevImgWithDeps, SpecConstBlob);
3070+ std::move (BinImgs), ContextImpl , Devs, &DevImgWithDeps, SpecConstBlob);
30693071
30703072 DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
30713073 MainInputImpl->get_bin_image_ref (), Context, Devs,
@@ -3185,7 +3187,8 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
31853187
31863188 if constexpr (DbgProgMgr > 0 )
31873189 std::cerr << " >>> Adding the kernel to the cache.\n " ;
3188- auto Program = createURProgram (Img, Context, {Device});
3190+ const ContextImplPtr &ContextImpl = detail::getSyclObjImpl (Context);
3191+ auto Program = createURProgram (Img, ContextImpl, {Device});
31893192 auto DeviceImpl = detail::getSyclObjImpl (Device);
31903193 auto &Adapter = DeviceImpl->getAdapter ();
31913194 UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
@@ -3200,8 +3203,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
32003203 std::vector<ur_program_handle_t > ExtraProgramsToLink;
32013204 std::vector<ur_device_handle_t > Devs = {DeviceImpl->getHandleRef ()};
32023205 auto BuildProgram =
3203- build (std::move (ProgramManaged), detail::getSyclObjImpl (Context),
3204- CompileOpts, LinkOpts, Devs,
3206+ build (std::move (ProgramManaged), ContextImpl, CompileOpts, LinkOpts, Devs,
32053207 /* For non SPIR-V devices DeviceLibReqdMask is always 0*/ 0 ,
32063208 ExtraProgramsToLink);
32073209 ur_kernel_handle_t UrKernel{nullptr };
0 commit comments