Skip to content

Commit 6611cc0

Browse files
committed
SWDEV-494149 - Improve hipGet/Set Device
Change-Id: If8975687a3ba9caadafc48a0066f19a4ebaab9e2
1 parent 2ca644c commit 6611cc0

File tree

3 files changed

+25
-17
lines changed

3 files changed

+25
-17
lines changed

hipamd/src/hip_context.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc.
1+
/* Copyright (c) 2015 - 2024 Advanced Micro Devices, Inc.
22
33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal
@@ -48,10 +48,11 @@ void init(bool* status) {
4848
}
4949
ClPrint(amd::LOG_INFO, amd::LOG_INIT, "Direct Dispatch: %d", AMD_DIRECT_DISPATCH);
5050

51-
5251
const std::vector<amd::Device*>& devices = amd::Device::getDevices(CL_DEVICE_TYPE_GPU, false);
52+
const size_t deviceCount = devices.size();
53+
g_devices.reserve(deviceCount); // Pre-allocate space for better performance
5354

54-
for (unsigned int i=0; i<devices.size(); i++) {
55+
for (unsigned int i = 0; i < deviceCount; i++) {
5556
// Enable active wait on the device by default
5657
devices[i]->SetActiveWait(true);
5758
// use the eternal contexts that already exist for new hip::Device's here

hipamd/src/hip_device_runtime.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -656,16 +656,17 @@ int ihipGetDevice() {
656656
hipError_t hipGetDevice(int* deviceId) {
657657
HIP_INIT_API(hipGetDevice, deviceId);
658658

659-
if (deviceId != nullptr) {
660-
int dev = ihipGetDevice();
661-
if (dev == -1) {
662-
HIP_RETURN(hipErrorNoDevice);
663-
}
664-
*deviceId = dev;
665-
HIP_RETURN(hipSuccess, *deviceId);
666-
} else {
659+
if (deviceId == nullptr) {
667660
HIP_RETURN(hipErrorInvalidValue);
668661
}
662+
663+
Device* device = hip::getCurrentDevice();
664+
if (device == nullptr) {
665+
HIP_RETURN(hipErrorNoDevice);
666+
}
667+
668+
*deviceId = device->deviceId();
669+
HIP_RETURN(hipSuccess, *deviceId);
669670
}
670671

671672
hipError_t hipGetDeviceCount(int* count) {
@@ -685,6 +686,12 @@ hipError_t hipGetDeviceFlags(unsigned int* flags) {
685686

686687
hipError_t hipSetDevice(int device) {
687688
HIP_INIT_API_NO_RETURN(hipSetDevice, device);
689+
690+
// Check if the device is already set
691+
if (hip::tls.device_ != nullptr && hip::tls.device_->deviceId() == device) {
692+
HIP_RETURN(hipSuccess);
693+
}
694+
688695
if (static_cast<unsigned int>(device) < g_devices.size()) {
689696
hip::setCurrentDevice(device);
690697

hipamd/src/hip_internal.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ const char* ihipGetErrorName(hipError_t hip_error);
112112
#define HIP_INIT(noReturn) \
113113
{ \
114114
bool status = true; \
115-
std::call_once(hip::g_ihipInitialized, hip::init, &status); \
115+
std::call_once(hip::g_ihipInitialized, hip::init, &status); \
116116
if (!status && !noReturn) { \
117117
HIP_RETURN(hipErrorInvalidDevice); \
118118
} \
@@ -125,16 +125,16 @@ const char* ihipGetErrorName(hipError_t hip_error);
125125
#define HIP_INIT_VOID() \
126126
{ \
127127
bool status = true; \
128-
std::call_once(hip::g_ihipInitialized, hip::init, &status); \
129-
if (hip::tls.device_ == nullptr && hip::g_devices.size() > 0) { \
130-
hip::tls.device_ = hip::g_devices[0]; \
131-
amd::Os::setPreferredNumaNode(hip::g_devices[0]->devices()[0]->getPreferredNumaNode()); \
128+
std::call_once(hip::g_ihipInitialized, hip::init, &status); \
129+
if (hip::tls.device_ == nullptr && hip::g_devices.size() > 0) { \
130+
hip::tls.device_ = hip::g_devices[0]; \
131+
amd::Os::setPreferredNumaNode(hip::g_devices[0]->devices()[0]->getPreferredNumaNode()); \
132132
} \
133133
}
134134

135135

136136
#define HIP_API_PRINT(...) \
137-
uint64_t startTimeUs=0; \
137+
uint64_t startTimeUs = 0; \
138138
HIPPrintDuration(amd::LOG_INFO, amd::LOG_API, &startTimeUs, \
139139
"%s %s ( %s ) %s", KGRN, \
140140
__func__, ToString( __VA_ARGS__ ).c_str(), KNRM);

0 commit comments

Comments
 (0)