Skip to content

Commit 6653a66

Browse files
[SYCL] Complete transition to Managed<ur_program_handle_t> RAII model (#19557)
Started in #19536. Non-NFC because fixes a few resource leaks as can be seen in the updated test.
1 parent 7b5838c commit 6653a66

File tree

10 files changed

+194
-203
lines changed

10 files changed

+194
-203
lines changed

sycl/source/backend.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
196196
adapter_impl &Adapter = getAdapter(Backend);
197197
context_impl &ContextImpl = *getSyclObjImpl(TargetContext);
198198

199-
ur_program_handle_t UrProgram = nullptr;
199+
Managed<ur_program_handle_t> UrProgram{Adapter};
200200
ur_program_native_properties_t Properties{};
201201
Properties.stype = UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES;
202202
Properties.isNativeHandleOwned = !KeepOwnership;
@@ -258,18 +258,19 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
258258
"Program and kernel_bundle state mismatch " +
259259
detail::codeToString(UR_RESULT_ERROR_INVALID_VALUE));
260260
if (State == bundle_state::executable) {
261-
ur_program_handle_t UrLinkedProgram = nullptr;
261+
Managed<ur_program_handle_t> UrLinkedProgram{Adapter};
262+
ur_program_handle_t ProgramsToLink[] = {UrProgram};
262263
auto Res = Adapter.call_nocheck<UrApiKind::urProgramLinkExp>(
263-
ContextImpl.getHandleRef(), 1u, &Dev, 1u, &UrProgram, nullptr,
264+
ContextImpl.getHandleRef(), 1u, &Dev, 1u, ProgramsToLink, nullptr,
264265
&UrLinkedProgram);
265266
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
266267
Res = Adapter.call_nocheck<UrApiKind::urProgramLink>(
267-
ContextImpl.getHandleRef(), 1u, &UrProgram, nullptr,
268+
ContextImpl.getHandleRef(), 1u, ProgramsToLink, nullptr,
268269
&UrLinkedProgram);
269270
}
270271
Adapter.checkUrResult<errc::build>(Res);
271272
if (UrLinkedProgram != nullptr) {
272-
UrProgram = UrLinkedProgram;
273+
UrProgram = std::move(UrLinkedProgram);
273274
}
274275
}
275276
break;
@@ -301,9 +302,9 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
301302
// do the same to user images, since they may contain references to undefined
302303
// symbols (e.g. when kernel_bundle is supposed to be joined with another).
303304
auto KernelIDs = std::make_shared<std::vector<kernel_id>>();
304-
auto DevImgImpl =
305-
device_image_impl::create(nullptr, TargetContext, Devices, State,
306-
KernelIDs, UrProgram, ImageOriginInterop);
305+
auto DevImgImpl = device_image_impl::create(
306+
nullptr, TargetContext, Devices, State, KernelIDs, std::move(UrProgram),
307+
ImageOriginInterop);
307308
device_image_plain DevImg{DevImgImpl};
308309

309310
return kernel_bundle_impl::create(TargetContext, Devices, DevImg);

sycl/source/detail/adapter_impl.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ template <typename URResource> class Managed {
244244
if constexpr (std::is_same_v<URResource, ur_program_handle_t>)
245245
return UrApiKind::urProgramRelease;
246246
}();
247+
static constexpr auto Retain = []() constexpr {
248+
if constexpr (std::is_same_v<URResource, ur_program_handle_t>)
249+
return UrApiKind::urProgramRetain;
250+
}();
247251

248252
public:
249253
Managed() = default;
@@ -258,6 +262,7 @@ template <typename URResource> class Managed {
258262
Managed &operator=(Managed &&Other) {
259263
if (R)
260264
Adapter->call<Release>(R);
265+
261266
R = Other.R;
262267
Other.R = nullptr;
263268
Adapter = Other.Adapter;
@@ -285,6 +290,18 @@ template <typename URResource> class Managed {
285290
Adapter->call<Release>(R);
286291
}
287292

293+
Managed retain() {
294+
assert(R && "Cannot retain unintialized resource!");
295+
Adapter->call<Retain>(R);
296+
return Managed{R, *Adapter};
297+
}
298+
299+
bool operator==(const Managed &Other) const {
300+
assert((!Adapter || !Other.Adapter || Adapter == Other.Adapter) &&
301+
"Objects must belong to the same adapter!");
302+
return R == Other.R;
303+
}
304+
288305
private:
289306
URResource R = nullptr;
290307
adapter_impl *Adapter = nullptr;

sycl/source/detail/device_image_impl.hpp

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,11 @@ class device_image_impl
257257
device_image_impl(const RTDeviceBinaryImage *BinImage, context Context,
258258
devices_range Devices, bundle_state State,
259259
std::shared_ptr<std::vector<kernel_id>> KernelIDs,
260-
ur_program_handle_t Program, uint8_t Origins, private_tag)
260+
Managed<ur_program_handle_t> &&Program, uint8_t Origins,
261+
private_tag)
261262
: MBinImage(BinImage), MContext(std::move(Context)),
262263
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
263-
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
264-
MKernelIDs(std::move(KernelIDs)),
264+
MProgram(std::move(Program)), MKernelIDs(std::move(KernelIDs)),
265265
MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(Origins) {
266266
updateSpecConstSymMap();
267267
if (BinImage && (MOrigins & ImageOriginSYCLBIN)) {
@@ -287,40 +287,23 @@ class device_image_impl
287287
const RTDeviceBinaryImage *BinImage, const context &Context,
288288
devices_range Devices, bundle_state State,
289289
std::shared_ptr<std::vector<kernel_id>> KernelIDs,
290-
ur_program_handle_t Program, const SpecConstMapT &SpecConstMap,
290+
Managed<ur_program_handle_t> &&Program, const SpecConstMapT &SpecConstMap,
291291
const std::vector<unsigned char> &SpecConstsBlob, uint8_t Origins,
292292
std::optional<KernelCompilerBinaryInfo> &&RTCInfo,
293293
KernelNameSetT &&KernelNames,
294294
KernelNameToArgMaskMap &&EliminatedKernelArgMasks,
295295
std::unique_ptr<DynRTDeviceBinaryImage> &&MergedImageStorage, private_tag)
296296
: MBinImage(BinImage), MContext(std::move(Context)),
297297
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
298-
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
299-
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
298+
MProgram(std::move(Program)), MKernelIDs(std::move(KernelIDs)),
299+
MKernelNames{std::move(KernelNames)},
300300
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
301301
MSpecConstsBlob(SpecConstsBlob),
302302
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
303303
MSpecConstSymMap(SpecConstMap), MOrigins(Origins),
304304
MRTCBinInfo(std::move(RTCInfo)),
305305
MMergedImageStorage(std::move(MergedImageStorage)) {}
306306

307-
device_image_impl(const RTDeviceBinaryImage *BinImage, const context &Context,
308-
devices_range Devices, bundle_state State,
309-
ur_program_handle_t Program, syclex::source_language Lang,
310-
KernelNameSetT &&KernelNames,
311-
KernelNameToArgMaskMap &&EliminatedKernelArgMasks,
312-
private_tag)
313-
: MBinImage(BinImage), MContext(std::move(Context)),
314-
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
315-
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
316-
MKernelNames{std::move(KernelNames)},
317-
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
318-
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
319-
MOrigins(ImageOriginKernelCompiler),
320-
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {
321-
updateSpecConstSymMap();
322-
}
323-
324307
device_image_impl(
325308
const RTDeviceBinaryImage *BinImage, const context &Context,
326309
devices_range Devices, bundle_state State,
@@ -366,14 +349,13 @@ class device_image_impl
366349
}
367350

368351
device_image_impl(const context &Context, devices_range Devices,
369-
bundle_state State, ur_program_handle_t Program,
352+
bundle_state State, Managed<ur_program_handle_t> &&Program,
370353
syclex::source_language Lang, KernelNameSetT &&KernelNames,
371354
private_tag)
372355
: MBinImage(static_cast<const RTDeviceBinaryImage *>(nullptr)),
373356
MContext(std::move(Context)),
374357
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
375-
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
376-
MKernelNames{std::move(KernelNames)},
358+
MProgram(std::move(Program)), MKernelNames{std::move(KernelNames)},
377359
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
378360
MOrigins(ImageOriginKernelCompiler),
379361
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {}
@@ -771,14 +753,14 @@ class device_image_impl
771753

772754
auto DeviceVec = Devices.to<std::vector<ur_device_handle_t>>();
773755

774-
ur_program_handle_t UrProgram = nullptr;
756+
Managed<ur_program_handle_t> UrProgram;
775757
// SourceStrPtr will be null when source is Spir-V bytes.
776758
const std::string *SourceStrPtr = std::get_if<std::string>(&MBinImage);
777-
bool FetchedFromCache = false;
778759
if (PersistentDeviceCodeCache::isEnabled() && SourceStrPtr) {
779-
FetchedFromCache = extKernelCompilerFetchFromCache(
780-
Devices, BuildOptions, *SourceStrPtr, UrProgram);
760+
UrProgram =
761+
extKernelCompilerFetchFromCache(Devices, BuildOptions, *SourceStrPtr);
781762
}
763+
bool FetchedFromCache = (UrProgram != nullptr);
782764

783765
adapter_impl &Adapter = ContextImpl.getAdapter();
784766

@@ -813,7 +795,7 @@ class device_image_impl
813795
}
814796
return std::vector<std::shared_ptr<device_image_impl>>{
815797
device_image_impl::create(MContext, Devices, bundle_state::executable,
816-
UrProgram, MRTCBinInfo->MLanguage,
798+
std::move(UrProgram), MRTCBinInfo->MLanguage,
817799
std::move(KernelNameSet))};
818800
}
819801

@@ -907,10 +889,10 @@ class device_image_impl
907889
return SS.str();
908890
}
909891

910-
bool extKernelCompilerFetchFromCache(
892+
Managed<ur_program_handle_t> extKernelCompilerFetchFromCache(
911893
devices_range Devices,
912894
const std::vector<sycl::detail::string_view> &BuildOptions,
913-
const std::string &SourceStr, ur_program_handle_t &UrProgram) const {
895+
const std::string &SourceStr) const {
914896
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
915897
adapter_impl &Adapter = ContextImpl.getAdapter();
916898

@@ -924,7 +906,7 @@ class device_image_impl
924906
PersistentDeviceCodeCache::getCompiledKernelFromDisc(Devices, UserArgs,
925907
SourceStr);
926908
if (BinProgs.empty()) {
927-
return false;
909+
return {};
928910
}
929911
for (auto &BinProg : BinProgs) {
930912
Binaries.push_back((uint8_t *)(BinProg.data()));
@@ -937,11 +919,12 @@ class device_image_impl
937919
Properties.count = 0;
938920
Properties.pMetadatas = nullptr;
939921

922+
Managed<ur_program_handle_t> UrProgram{Adapter};
940923
Adapter.call<UrApiKind::urProgramCreateWithBinary>(
941924
ContextImpl.getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
942925
Lengths.data(), Binaries.data(), &Properties, &UrProgram);
943926

944-
return true;
927+
return UrProgram;
945928
}
946929

947930
// Get the specialization constant default value blob.
@@ -1226,7 +1209,7 @@ class device_image_impl
12261209
return Result;
12271210
}
12281211

1229-
ur_program_handle_t
1212+
Managed<ur_program_handle_t>
12301213
createProgramFromSource(devices_range Devices,
12311214
const std::vector<sycl::detail::string_view> &Options,
12321215
std::string *LogPtr) const {
@@ -1266,11 +1249,10 @@ class device_image_impl
12661249
"languages at this time");
12671250
}();
12681251

1269-
ur_program_handle_t UrProgram = nullptr;
1252+
Managed<ur_program_handle_t> UrProgram{Adapter};
12701253
Adapter.call<UrApiKind::urProgramCreateWithIL>(ContextImpl.getHandleRef(),
12711254
spirv.data(), spirv.size(),
12721255
nullptr, &UrProgram);
1273-
// program created by urProgramCreateWithIL is implicitly retained.
12741256
if (UrProgram == nullptr)
12751257
throw sycl::exception(
12761258
sycl::make_error_code(errc::invalid),

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,8 @@ class kernel_bundle_impl
597597
for (const detail::RTDeviceBinaryImage *Image : BestImages)
598598
MDeviceImages.emplace_back(device_image_impl::create(
599599
Image, Context, Devs, ProgramManager::getBinImageState(Image),
600-
/*KernelIDs=*/nullptr, /*URProgram=*/nullptr, ImageOriginSYCLBIN));
600+
/*KernelIDs=*/nullptr, Managed<ur_program_handle_t>{},
601+
ImageOriginSYCLBIN));
601602
ProgramManager::getInstance().bringSYCLDeviceImagesToState(MDeviceImages,
602603
State);
603604
fillUniqueDeviceImages();

sycl/source/detail/kernel_name_based_cache_t.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ struct FastKernelCacheVal {
3636

3737
FastKernelCacheVal(ur_kernel_handle_t KernelHandle, std::mutex *Mutex,
3838
const KernelArgMask *KernelArgMask,
39-
ur_program_handle_t ProgramHandle, adapter_impl &Adapter)
39+
Managed<ur_program_handle_t> &&ProgramHandle,
40+
adapter_impl &Adapter)
4041
: MKernelHandle(KernelHandle), MMutex(Mutex),
41-
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle, Adapter),
42+
MKernelArgMask(KernelArgMask), MProgramHandle(std::move(ProgramHandle)),
4243
MAdapter(Adapter) {}
4344

4445
~FastKernelCacheVal() {

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,11 @@ class KernelProgramCache {
112112
};
113113

114114
struct ProgramBuildResult : public BuildResult<Managed<ur_program_handle_t>> {
115-
ProgramBuildResult(adapter_impl &Adapter) {
116-
Val = Managed<ur_program_handle_t>{Adapter};
117-
}
118-
ProgramBuildResult(adapter_impl &Adapter, BuildState InitialState) {
119-
Val = Managed<ur_program_handle_t>{Adapter};
115+
ProgramBuildResult() = default;
116+
ProgramBuildResult(BuildState InitialState,
117+
Managed<ur_program_handle_t> &&Prog) {
120118
this->State.store(InitialState);
119+
this->Val = std::move(Prog);
121120
}
122121
#ifdef _MSC_VER
123122
#pragma warning(push)
@@ -407,7 +406,7 @@ class KernelProgramCache {
407406
ProgramCache &ProgCache = LockedCache.get();
408407
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
409408
if (DidInsert) {
410-
It->second = std::make_shared<ProgramBuildResult>(getAdapter());
409+
It->second = std::make_shared<ProgramBuildResult>();
411410
// Save reference between the common key and the full key.
412411
CommonProgramKeyT CommonKey =
413412
std::make_pair(CacheKey.first.second, CacheKey.second);
@@ -424,14 +423,13 @@ class KernelProgramCache {
424423
//
425424
// Returns whether or not an insertion took place.
426425
bool insertBuiltProgram(const ProgramCacheKeyT &CacheKey,
427-
ur_program_handle_t Program) {
426+
Managed<ur_program_handle_t> &Program) {
428427
auto LockedCache = acquireCachedPrograms();
429428
ProgramCache &ProgCache = LockedCache.get();
430429
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
431430
if (DidInsert) {
432-
It->second = std::make_shared<ProgramBuildResult>(getAdapter(),
433-
BuildState::BS_Done);
434-
It->second->Val = Managed<ur_program_handle_t>{Program, getAdapter()};
431+
It->second = std::make_shared<ProgramBuildResult>(BuildState::BS_Done,
432+
Program.retain());
435433
// Save reference between the common key and the full key.
436434
CommonProgramKeyT CommonKey =
437435
std::make_pair(CacheKey.first.second, CacheKey.second);
@@ -643,8 +641,7 @@ class KernelProgramCache {
643641
// If it is the first time the program is fetched, add it to the eviction
644642
// list.
645643
void registerProgramFetch(const ProgramCacheKeyT &CacheKey,
646-
const ur_program_handle_t &Program,
647-
const bool IsBuilt) {
644+
ur_program_handle_t Program, const bool IsBuilt) {
648645

649646
size_t ProgramCacheEvictionThreshold =
650647
SYCLConfig<SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::getProgramCacheSize();
@@ -799,9 +796,10 @@ class KernelProgramCache {
799796

800797
// only the building thread will run this
801798
try {
802-
// Remove `adapter_impl` from `ProgramBuildResult`'s ctors once `Build`
803-
// returns `Managed<ur_platform_handle_t`:
804-
*(&BuildResult->Val) = Build();
799+
static_assert(
800+
std::is_same_v<decltype(Build()), decltype(BuildResult->Val)>,
801+
"Are we casting from Managed<URResource> to plain URResource?");
802+
BuildResult->Val = Build();
805803

806804
if constexpr (!std::is_same_v<EvictFT, void *>)
807805
EvictFunc(BuildResult->Val, /*IsBuilt=*/true);

0 commit comments

Comments
 (0)