Skip to content

Commit 9f771c1

Browse files
committed
Improve per suggestion
1 parent 63091d7 commit 9f771c1

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

unified-runtime/source/adapters/cuda/device.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,13 @@ struct ur_device_handle_t_ {
3737
bool MaxLocalMemSizeChosen{false};
3838
uint32_t NumComputeUnits{0};
3939
std::once_flag NVMLInitFlag;
40-
bool NVMLUsed{false};
40+
std::optional<nvmlDevice_t> NVMLDevice;
4141

4242
public:
4343
ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
4444
ur_platform_handle_t platform, uint32_t DevIndex)
4545
: CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase), RefCount{1},
4646
Platform(platform), DeviceIndex{DevIndex} {
47-
4847
UR_CHECK_ERROR(cuDeviceGetAttribute(
4948
&MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK,
5049
cuDevice));
@@ -104,7 +103,7 @@ struct ur_device_handle_t_ {
104103
if (MemoryProviderShared) {
105104
umfMemoryProviderDestroy(MemoryProviderShared);
106105
}
107-
if (NVMLUsed) {
106+
if (NVMLDevice.has_value()) {
108107
UR_CHECK_ERROR(nvmlShutdown());
109108
}
110109
cuDevicePrimaryCtxRelease(CuDevice);
@@ -119,11 +118,11 @@ struct ur_device_handle_t_ {
119118
// left, resources will be released.
120119
std::call_once(NVMLInitFlag, [this]() {
121120
UR_CHECK_ERROR(nvmlInit());
122-
NVMLUsed = true;
121+
nvmlDevice_t Handle;
122+
UR_CHECK_ERROR(nvmlDeviceGetHandleByIndex(DeviceIndex, &Handle));
123+
NVMLDevice = Handle;
123124
});
124-
nvmlDevice_t NVMLDevice;
125-
UR_CHECK_ERROR(nvmlDeviceGetHandleByIndex(DeviceIndex, &NVMLDevice));
126-
return NVMLDevice;
125+
return NVMLDevice.value();
127126
};
128127

129128
CUcontext getNativeContext() const noexcept { return CuContext; };

0 commit comments

Comments
 (0)