Skip to content

Commit 3fa14e0

Browse files
[NFC][SYCL] Better "managed" ur_program_handle_t (#19536)
There was `ProgramManager::ProgramPtr` alias over `std::unique_ptr` with a custom deleter to RAII-manage `ur_program_handle_t` lifetime but it was applied in just a few places with the rest of the usage left with C-style explicit management. This PR introduce a dedicated helper class to manage all UR handle types that I think is more convenient than `ProgramManager::ProgramPtr`. I'm also switching all the objects that stored `ur_program_handle_t` and then `urProgramRelease`d them to use that new helper, while leaving the full refactoring (i.e., create those `Managed` objects at `urProgramCreate*`/`urProgramRetain` point) to a subsequent PRs to ease review process. Other `ur*_handle_t`s are left to subsequent changes as well.
1 parent 9dd1952 commit 3fa14e0

15 files changed

+177
-135
lines changed

sycl/source/backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ kernel make_kernel(const context &TargetContext,
343343
const device_image<bundle_state::executable> &DeviceImage =
344344
*KernelBundle.begin();
345345
device_image_impl &DeviceImageImpl = *getSyclObjImpl(DeviceImage);
346-
UrProgram = DeviceImageImpl.get_ur_program_ref();
346+
UrProgram = DeviceImageImpl.get_ur_program();
347347
}
348348

349349
// Create UR kernel first.

sycl/source/detail/adapter_impl.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,56 @@ class adapter_impl {
239239
UrFuncPtrMapT UrFuncPtrs;
240240
}; // class adapter_impl
241241

242+
template <typename URResource> class Managed {
243+
static constexpr auto Release = []() constexpr {
244+
if constexpr (std::is_same_v<URResource, ur_program_handle_t>)
245+
return UrApiKind::urProgramRelease;
246+
}();
247+
248+
public:
249+
Managed() = default;
250+
Managed(URResource R, adapter_impl &Adapter) : R(R), Adapter(&Adapter) {}
251+
Managed(adapter_impl &Adapter) : Adapter(&Adapter) {}
252+
Managed(const Managed &) = delete;
253+
Managed(Managed &&Other) : Adapter(Other.Adapter) {
254+
R = Other.R;
255+
Other.R = nullptr;
256+
}
257+
Managed &operator=(const Managed &) = delete;
258+
Managed &operator=(Managed &&Other) {
259+
if (R)
260+
Adapter->call<Release>(R);
261+
R = Other.R;
262+
Other.R = nullptr;
263+
Adapter = Other.Adapter;
264+
return *this;
265+
}
266+
267+
operator URResource() const { return R; }
268+
269+
URResource release() {
270+
URResource Res = R;
271+
R = nullptr;
272+
return Res;
273+
}
274+
275+
URResource *operator&() {
276+
assert(!R && "Already initialized!");
277+
assert(Adapter && "Adapter must be set for this API!");
278+
return &R;
279+
}
280+
281+
~Managed() {
282+
if (!R)
283+
return;
284+
285+
Adapter->call<Release>(R);
286+
}
287+
288+
private:
289+
URResource R = nullptr;
290+
adapter_impl *Adapter = nullptr;
291+
};
242292
} // namespace detail
243293
} // namespace _V1
244294
} // namespace sycl

sycl/source/detail/context_impl.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,7 @@ context_impl::~context_impl() {
128128
DeviceGlobal);
129129
DGEntry->removeAssociatedResources(this);
130130
}
131-
for (auto LibProg : MCachedLibPrograms) {
132-
assert(LibProg.second && "Null program must not be kept in the cache");
133-
getAdapter().call<UrApiKind::urProgramRelease>(LibProg.second);
134-
}
131+
MCachedLibPrograms.clear();
135132
// TODO catch an exception and put it to list of asynchronous exceptions
136133
getAdapter().call_nocheck<UrApiKind::urContextRelease>(MContext);
137134
} catch (std::exception &e) {

sycl/source/detail/context_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
134134

135135
using CachedLibProgramsT =
136136
std::map<std::pair<DeviceLibExt, ur_device_handle_t>,
137-
ur_program_handle_t>;
137+
Managed<ur_program_handle_t>>;
138138

139139
/// In contrast to user programs, which are compiled from user code, library
140140
/// programs come from the SYCL runtime. They are identified by the

sycl/source/detail/device_image_impl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
3030
if (!KID || !has_kernel(*KID))
3131
continue;
3232

33-
auto UrProgram = get_ur_program_ref();
33+
auto UrProgram = get_ur_program();
3434
auto [UrKernel, CacheMutex, ArgMask] =
3535
PM.getOrCreateKernel(Context, AdjustedName,
3636
/*PropList=*/{}, UrProgram);
@@ -41,7 +41,7 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
4141
return nullptr;
4242
}
4343

44-
ur_program_handle_t UrProgram = get_ur_program_ref();
44+
ur_program_handle_t UrProgram = get_ur_program();
4545
detail::adapter_impl &Adapter = getSyclObjImpl(Context)->getAdapter();
4646
ur_kernel_handle_t UrKernel = nullptr;
4747
Adapter.call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),

sycl/source/detail/device_image_impl.hpp

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ class device_image_impl
260260
ur_program_handle_t Program, uint8_t Origins, private_tag)
261261
: MBinImage(BinImage), MContext(std::move(Context)),
262262
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
263-
MProgram(Program), MKernelIDs(std::move(KernelIDs)),
263+
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
264+
MKernelIDs(std::move(KernelIDs)),
264265
MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(Origins) {
265266
updateSpecConstSymMap();
266267
if (BinImage && (MOrigins & ImageOriginSYCLBIN)) {
@@ -294,8 +295,8 @@ class device_image_impl
294295
std::unique_ptr<DynRTDeviceBinaryImage> &&MergedImageStorage, private_tag)
295296
: MBinImage(BinImage), MContext(std::move(Context)),
296297
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
297-
MProgram(Program), MKernelIDs(std::move(KernelIDs)),
298-
MKernelNames{std::move(KernelNames)},
298+
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
299+
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
299300
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
300301
MSpecConstsBlob(SpecConstsBlob),
301302
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
@@ -311,7 +312,8 @@ class device_image_impl
311312
private_tag)
312313
: MBinImage(BinImage), MContext(std::move(Context)),
313314
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
314-
MProgram(Program), MKernelNames{std::move(KernelNames)},
315+
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
316+
MKernelNames{std::move(KernelNames)},
315317
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
316318
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
317319
MOrigins(ImageOriginKernelCompiler),
@@ -329,8 +331,7 @@ class device_image_impl
329331
private_tag)
330332
: MBinImage(BinImage), MContext(std::move(Context)),
331333
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
332-
MProgram(nullptr), MKernelIDs(std::move(KernelIDs)),
333-
MKernelNames{std::move(KernelNames)},
334+
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
334335
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
335336
MOrigins(ImageOriginKernelCompiler),
336337
MRTCBinInfo(KernelCompilerBinaryInfo{
@@ -344,7 +345,7 @@ class device_image_impl
344345
include_pairs_t &&IncludePairsVec, private_tag)
345346
: MBinImage(Src), MContext(std::move(Context)),
346347
MDevices(Devices.to<std::vector<device_impl *>>()),
347-
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
348+
MState(bundle_state::ext_oneapi_source),
348349
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
349350
MOrigins(ImageOriginKernelCompiler),
350351
MRTCBinInfo(
@@ -357,7 +358,7 @@ class device_image_impl
357358
private_tag)
358359
: MBinImage(Bytes), MContext(std::move(Context)),
359360
MDevices(Devices.to<std::vector<device_impl *>>()),
360-
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
361+
MState(bundle_state::ext_oneapi_source),
361362
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
362363
MOrigins(ImageOriginKernelCompiler),
363364
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {
@@ -371,7 +372,8 @@ class device_image_impl
371372
: MBinImage(static_cast<const RTDeviceBinaryImage *>(nullptr)),
372373
MContext(std::move(Context)),
373374
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
374-
MProgram(Program), MKernelNames{std::move(KernelNames)},
375+
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
376+
MKernelNames{std::move(KernelNames)},
375377
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
376378
MOrigins(ImageOriginKernelCompiler),
377379
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {}
@@ -558,9 +560,7 @@ class device_image_impl
558560
return get_devices().contains(Dev);
559561
}
560562

561-
const ur_program_handle_t &get_ur_program_ref() const noexcept {
562-
return MProgram;
563-
}
563+
ur_program_handle_t get_ur_program() const noexcept { return MProgram; }
564564

565565
const RTDeviceBinaryImage *const &get_bin_image_ref() const {
566566
return std::get<const RTDeviceBinaryImage *>(MBinImage);
@@ -617,21 +617,25 @@ class device_image_impl
617617
return NativeProgram;
618618
}
619619

620-
~device_image_impl() {
621-
try {
622-
if (MProgram) {
623-
adapter_impl &Adapter = getSyclObjImpl(MContext)->getAdapter();
624-
Adapter.call<UrApiKind::urProgramRelease>(MProgram);
625-
}
626-
if (MSpecConstsBuffer) {
627-
std::lock_guard<std::mutex> Lock{MSpecConstAccessMtx};
628-
adapter_impl &Adapter = getSyclObjImpl(MContext)->getAdapter();
629-
memReleaseHelper(Adapter, MSpecConstsBuffer);
630-
}
631-
} catch (std::exception &e) {
632-
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~device_image_impl", e);
620+
#ifdef _MSC_VER
621+
#pragma warning(push)
622+
// https://developercommunity.visualstudio.com/t/False-C4297-warning-while-using-function/1130300
623+
// https://godbolt.org/z/xsMvKf84f
624+
#pragma warning(disable : 4297)
625+
#endif
626+
~device_image_impl() try {
627+
if (MSpecConstsBuffer) {
628+
std::lock_guard<std::mutex> Lock{MSpecConstAccessMtx};
629+
adapter_impl &Adapter = getSyclObjImpl(MContext)->getAdapter();
630+
memReleaseHelper(Adapter, MSpecConstsBuffer);
633631
}
632+
} catch (std::exception &e) {
633+
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~device_image_impl", e);
634+
return; // Don't re-throw.
634635
}
636+
#ifdef _MSC_VER
637+
#pragma warning(pop)
638+
#endif
635639

636640
std::string adjustKernelName(std::string_view Name) const {
637641
if (MOrigins & ImageOriginSYCLBIN) {
@@ -1298,7 +1302,7 @@ class device_image_impl
12981302
std::vector<device_impl *> MDevices;
12991303
bundle_state MState;
13001304
// Native program handler which this device image represents
1301-
ur_program_handle_t MProgram = nullptr;
1305+
Managed<ur_program_handle_t> MProgram;
13021306

13031307
// List of kernel ids available in this image, elements should be sorted
13041308
// according to LessByNameComp. Shared between images for performance reasons

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,11 +991,11 @@ class kernel_bundle_impl
991991
auto [Kernel, CacheMutex, ArgMask] =
992992
detail::ProgramManager::getInstance().getOrCreateKernel(
993993
MContext, KernelID.get_name(), /*PropList=*/{},
994-
SelectedImage->get_ur_program_ref());
994+
SelectedImage->get_ur_program());
995995

996996
return std::make_shared<kernel_impl>(
997997
Kernel, *detail::getSyclObjImpl(MContext), SelectedImage, *this,
998-
ArgMask, SelectedImage->get_ur_program_ref(), CacheMutex);
998+
ArgMask, SelectedImage->get_ur_program(), CacheMutex);
999999
}
10001000

10011001
std::shared_ptr<kernel_impl>

sycl/source/detail/kernel_name_based_cache_t.hpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,26 @@ struct FastKernelCacheVal {
2727
caching is disabled, the pointer is
2828
nullptr. */
2929
const KernelArgMask *MKernelArgMask; /* Eliminated kernel argument mask. */
30-
ur_program_handle_t MProgramHandle; /* UR program handle corresponding to
31-
this kernel. */
32-
const adapter_impl &MAdapterPtr; /* We can keep reference to the adapter
33-
because during 2-stage shutdown the kernel
34-
cache is destroyed deliberately before the
35-
adapter. */
30+
Managed<ur_program_handle_t> MProgramHandle; /* UR program handle
31+
corresponding to this kernel. */
32+
adapter_impl &MAdapter; /* We can keep reference to the adapter
33+
because during 2-stage shutdown the kernel
34+
cache is destroyed deliberately before the
35+
adapter. */
3636

3737
FastKernelCacheVal(ur_kernel_handle_t KernelHandle, std::mutex *Mutex,
3838
const KernelArgMask *KernelArgMask,
39-
ur_program_handle_t ProgramHandle,
40-
const adapter_impl &AdapterPtr)
39+
ur_program_handle_t ProgramHandle, adapter_impl &Adapter)
4140
: MKernelHandle(KernelHandle), MMutex(Mutex),
42-
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle),
43-
MAdapterPtr(AdapterPtr) {}
41+
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle, Adapter),
42+
MAdapter(Adapter) {}
4443

4544
~FastKernelCacheVal() {
4645
if (MKernelHandle)
47-
MAdapterPtr.call<sycl::detail::UrApiKind::urKernelRelease>(MKernelHandle);
48-
if (MProgramHandle)
49-
MAdapterPtr.call<sycl::detail::UrApiKind::urProgramRelease>(
50-
MProgramHandle);
46+
MAdapter.call<sycl::detail::UrApiKind::urKernelRelease>(MKernelHandle);
5147
MKernelHandle = nullptr;
5248
MMutex = nullptr;
5349
MKernelArgMask = nullptr;
54-
MProgramHandle = nullptr;
5550
}
5651

5752
FastKernelCacheVal(const FastKernelCacheVal &) = delete;

sycl/source/detail/kernel_program_cache.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
namespace sycl {
1313
inline namespace _V1 {
1414
namespace detail {
15-
const adapter_impl &KernelProgramCache::getAdapter() {
15+
adapter_impl &KernelProgramCache::getAdapter() {
1616
return MParentContext->getAdapter();
1717
}
1818

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -111,28 +111,28 @@ class KernelProgramCache {
111111
}
112112
};
113113

114-
struct ProgramBuildResult : public BuildResult<ur_program_handle_t> {
115-
const adapter_impl &MAdapter;
116-
ProgramBuildResult(const adapter_impl &Adapter) : MAdapter(Adapter) {
117-
Val = nullptr;
114+
struct ProgramBuildResult : public BuildResult<Managed<ur_program_handle_t>> {
115+
ProgramBuildResult(adapter_impl &Adapter) {
116+
Val = Managed<ur_program_handle_t>{Adapter};
118117
}
119-
ProgramBuildResult(const adapter_impl &Adapter, BuildState InitialState)
120-
: MAdapter(Adapter) {
121-
Val = nullptr;
118+
ProgramBuildResult(adapter_impl &Adapter, BuildState InitialState) {
119+
Val = Managed<ur_program_handle_t>{Adapter};
122120
this->State.store(InitialState);
123121
}
124-
~ProgramBuildResult() {
125-
try {
126-
if (Val) {
127-
ur_result_t Err =
128-
MAdapter.call_nocheck<UrApiKind::urProgramRelease>(Val);
129-
__SYCL_CHECK_UR_CODE_NO_EXC(Err, MAdapter.getBackend());
130-
}
131-
} catch (std::exception &e) {
132-
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ProgramBuildResult",
133-
e);
134-
}
122+
#ifdef _MSC_VER
123+
#pragma warning(push)
124+
// https://developercommunity.visualstudio.com/t/False-C4297-warning-while-using-function/1130300
125+
// https://godbolt.org/z/xsMvKf84f
126+
#pragma warning(disable : 4297)
127+
#endif
128+
~ProgramBuildResult() try {
129+
} catch (std::exception &e) {
130+
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ProgramBuildResult", e);
131+
return; // Don't re-throw.
135132
}
133+
#ifdef _MSC_VER
134+
#pragma warning(pop)
135+
#endif
136136
};
137137
using ProgramBuildResultPtr = std::shared_ptr<ProgramBuildResult>;
138138

@@ -434,7 +434,7 @@ class KernelProgramCache {
434434
if (DidInsert) {
435435
It->second = std::make_shared<ProgramBuildResult>(getAdapter(),
436436
BuildState::BS_Done);
437-
It->second->Val = Program;
437+
It->second->Val = Managed<ur_program_handle_t>{Program, getAdapter()};
438438
// Save reference between the common key and the full key.
439439
CommonProgramKeyT CommonKey =
440440
std::make_pair(CacheKey.first.second, CacheKey.second);
@@ -794,7 +794,9 @@ class KernelProgramCache {
794794

795795
// only the building thread will run this
796796
try {
797-
BuildResult->Val = Build();
797+
// Remove `adapter_impl` from `ProgramBuildResult`'s ctors once `Build`
798+
// returns `Managed<ur_platform_handle_t`:
799+
*(&BuildResult->Val) = Build();
798800

799801
if constexpr (!std::is_same_v<EvictFT, void *>)
800802
EvictFunc(BuildResult->Val, /*IsBuilt=*/true);
@@ -868,7 +870,7 @@ class KernelProgramCache {
868870

869871
friend class ::MockKernelProgramCache;
870872

871-
const adapter_impl &getAdapter();
873+
adapter_impl &getAdapter();
872874
ur_context_handle_t getURContext() const;
873875
};
874876
} // namespace detail

0 commit comments

Comments
 (0)