@@ -257,11 +257,11 @@ class device_image_impl
257
257
device_image_impl (const RTDeviceBinaryImage *BinImage, context Context,
258
258
devices_range Devices, bundle_state State,
259
259
std::shared_ptr<std::vector<kernel_id>> KernelIDs,
260
- ur_program_handle_t Program, uint8_t Origins, private_tag)
260
+ Managed<ur_program_handle_t > &&Program, uint8_t Origins,
261
+ private_tag)
261
262
: MBinImage(BinImage), MContext(std::move(Context)),
262
263
MDevices (Devices.to<std::vector<device_impl *>>()), MState(State),
263
- MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
264
- MKernelIDs(std::move(KernelIDs)),
264
+ MProgram(std::move(Program)), MKernelIDs(std::move(KernelIDs)),
265
265
MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(Origins) {
266
266
updateSpecConstSymMap ();
267
267
if (BinImage && (MOrigins & ImageOriginSYCLBIN)) {
@@ -287,40 +287,23 @@ class device_image_impl
287
287
const RTDeviceBinaryImage *BinImage, const context &Context,
288
288
devices_range Devices, bundle_state State,
289
289
std::shared_ptr<std::vector<kernel_id>> KernelIDs,
290
- ur_program_handle_t Program, const SpecConstMapT &SpecConstMap,
290
+ Managed< ur_program_handle_t > && Program, const SpecConstMapT &SpecConstMap,
291
291
const std::vector<unsigned char > &SpecConstsBlob, uint8_t Origins,
292
292
std::optional<KernelCompilerBinaryInfo> &&RTCInfo,
293
293
KernelNameSetT &&KernelNames,
294
294
KernelNameToArgMaskMap &&EliminatedKernelArgMasks,
295
295
std::unique_ptr<DynRTDeviceBinaryImage> &&MergedImageStorage, private_tag)
296
296
: MBinImage(BinImage), MContext(std::move(Context)),
297
297
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
298
- MProgram(Program, getSyclObjImpl(MContext)->getAdapter( )),
299
- MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move (KernelNames)},
298
+ MProgram(std::move( Program)), MKernelIDs(std::move(KernelIDs )),
299
+ MKernelNames{std::move (KernelNames)},
300
300
MEliminatedKernelArgMasks{std::move (EliminatedKernelArgMasks)},
301
301
MSpecConstsBlob (SpecConstsBlob),
302
302
MSpecConstsDefValBlob (getSpecConstsDefValBlob()),
303
303
MSpecConstSymMap (SpecConstMap), MOrigins(Origins),
304
304
MRTCBinInfo (std::move(RTCInfo)),
305
305
MMergedImageStorage (std::move(MergedImageStorage)) {}
306
306
307
- device_image_impl (const RTDeviceBinaryImage *BinImage, const context &Context,
308
- devices_range Devices, bundle_state State,
309
- ur_program_handle_t Program, syclex::source_language Lang,
310
- KernelNameSetT &&KernelNames,
311
- KernelNameToArgMaskMap &&EliminatedKernelArgMasks,
312
- private_tag)
313
- : MBinImage(BinImage), MContext(std::move(Context)),
314
- MDevices (Devices.to<std::vector<device_impl *>>()), MState(State),
315
- MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
316
- MKernelNames{std::move (KernelNames)},
317
- MEliminatedKernelArgMasks{std::move (EliminatedKernelArgMasks)},
318
- MSpecConstsDefValBlob (getSpecConstsDefValBlob()),
319
- MOrigins (ImageOriginKernelCompiler),
320
- MRTCBinInfo (KernelCompilerBinaryInfo{Lang}) {
321
- updateSpecConstSymMap ();
322
- }
323
-
324
307
device_image_impl (
325
308
const RTDeviceBinaryImage *BinImage, const context &Context,
326
309
devices_range Devices, bundle_state State,
@@ -366,14 +349,13 @@ class device_image_impl
366
349
}
367
350
368
351
device_image_impl (const context &Context, devices_range Devices,
369
- bundle_state State, ur_program_handle_t Program,
352
+ bundle_state State, Managed< ur_program_handle_t > && Program,
370
353
syclex::source_language Lang, KernelNameSetT &&KernelNames,
371
354
private_tag)
372
355
: MBinImage(static_cast <const RTDeviceBinaryImage *>(nullptr )),
373
356
MContext(std::move(Context)),
374
357
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
375
- MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
376
- MKernelNames{std::move (KernelNames)},
358
+ MProgram(std::move(Program)), MKernelNames{std::move (KernelNames)},
377
359
MSpecConstsDefValBlob (getSpecConstsDefValBlob()),
378
360
MOrigins(ImageOriginKernelCompiler),
379
361
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {}
@@ -771,14 +753,14 @@ class device_image_impl
771
753
772
754
auto DeviceVec = Devices.to <std::vector<ur_device_handle_t >>();
773
755
774
- ur_program_handle_t UrProgram = nullptr ;
756
+ Managed< ur_program_handle_t > UrProgram;
775
757
// SourceStrPtr will be null when source is Spir-V bytes.
776
758
const std::string *SourceStrPtr = std::get_if<std::string>(&MBinImage);
777
- bool FetchedFromCache = false ;
778
759
if (PersistentDeviceCodeCache::isEnabled () && SourceStrPtr) {
779
- FetchedFromCache = extKernelCompilerFetchFromCache (
780
- Devices, BuildOptions, *SourceStrPtr, UrProgram );
760
+ UrProgram =
761
+ extKernelCompilerFetchFromCache ( Devices, BuildOptions, *SourceStrPtr);
781
762
}
763
+ bool FetchedFromCache = (UrProgram != nullptr );
782
764
783
765
adapter_impl &Adapter = ContextImpl.getAdapter ();
784
766
@@ -813,7 +795,7 @@ class device_image_impl
813
795
}
814
796
return std::vector<std::shared_ptr<device_image_impl>>{
815
797
device_image_impl::create (MContext, Devices, bundle_state::executable,
816
- UrProgram, MRTCBinInfo->MLanguage ,
798
+ std::move ( UrProgram) , MRTCBinInfo->MLanguage ,
817
799
std::move (KernelNameSet))};
818
800
}
819
801
@@ -907,10 +889,10 @@ class device_image_impl
907
889
return SS.str ();
908
890
}
909
891
910
- bool extKernelCompilerFetchFromCache (
892
+ Managed< ur_program_handle_t > extKernelCompilerFetchFromCache (
911
893
devices_range Devices,
912
894
const std::vector<sycl::detail::string_view> &BuildOptions,
913
- const std::string &SourceStr, ur_program_handle_t &UrProgram ) const {
895
+ const std::string &SourceStr) const {
914
896
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl (MContext);
915
897
adapter_impl &Adapter = ContextImpl.getAdapter ();
916
898
@@ -924,7 +906,7 @@ class device_image_impl
924
906
PersistentDeviceCodeCache::getCompiledKernelFromDisc (Devices, UserArgs,
925
907
SourceStr);
926
908
if (BinProgs.empty ()) {
927
- return false ;
909
+ return {} ;
928
910
}
929
911
for (auto &BinProg : BinProgs) {
930
912
Binaries.push_back ((uint8_t *)(BinProg.data ()));
@@ -937,11 +919,12 @@ class device_image_impl
937
919
Properties.count = 0 ;
938
920
Properties.pMetadatas = nullptr ;
939
921
922
+ Managed<ur_program_handle_t > UrProgram{Adapter};
940
923
Adapter.call <UrApiKind::urProgramCreateWithBinary>(
941
924
ContextImpl.getHandleRef (), DeviceHandles.size (), DeviceHandles.data (),
942
925
Lengths.data (), Binaries.data (), &Properties, &UrProgram);
943
926
944
- return true ;
927
+ return UrProgram ;
945
928
}
946
929
947
930
// Get the specialization constant default value blob.
@@ -1226,7 +1209,7 @@ class device_image_impl
1226
1209
return Result;
1227
1210
}
1228
1211
1229
- ur_program_handle_t
1212
+ Managed< ur_program_handle_t >
1230
1213
createProgramFromSource (devices_range Devices,
1231
1214
const std::vector<sycl::detail::string_view> &Options,
1232
1215
std::string *LogPtr) const {
@@ -1266,11 +1249,10 @@ class device_image_impl
1266
1249
" languages at this time" );
1267
1250
}();
1268
1251
1269
- ur_program_handle_t UrProgram = nullptr ;
1252
+ Managed< ur_program_handle_t > UrProgram{Adapter} ;
1270
1253
Adapter.call <UrApiKind::urProgramCreateWithIL>(ContextImpl.getHandleRef (),
1271
1254
spirv.data (), spirv.size (),
1272
1255
nullptr , &UrProgram);
1273
- // program created by urProgramCreateWithIL is implicitly retained.
1274
1256
if (UrProgram == nullptr )
1275
1257
throw sycl::exception (
1276
1258
sycl::make_error_code (errc::invalid),
0 commit comments