Skip to content

Commit 2f06e57

Browse files
committed
Localize thread Id
1 parent 377769e commit 2f06e57

File tree

13 files changed

+110
-82
lines changed

13 files changed

+110
-82
lines changed

interfaces/cuda/Control.cu

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <iomanip>
1010
#include <sstream>
1111
#include <string>
12+
#include <thread>
1213

1314
#ifdef PROFILING_ENABLED
1415
#include <nvToolsExt.h>
@@ -26,16 +27,16 @@ ConcreteAPI::ConcreteAPI() {
2627
}
2728

2829
void ConcreteAPI::setDevice(int deviceId) {
29-
currentDeviceId = deviceId;
30-
cudaSetDevice(currentDeviceId);
30+
deviceMap[std::this_thread::get_id()] = deviceId;
31+
cudaSetDevice(deviceId);
3132
CHECK_ERR;
3233

3334
// Note: the following sets the initial CUDA context
3435
cudaFree(nullptr);
3536
CHECK_ERR;
3637

3738
int result;
38-
cudaDeviceGetAttribute(&result, cudaDevAttrDirectManagedMemAccessFromHost, currentDeviceId);
39+
cudaDeviceGetAttribute(&result, cudaDevAttrDirectManagedMemAccessFromHost, getDeviceId());
3940
usmDefault = result != 0;
4041

4142
status[StatusID::DeviceSelected] = true;
@@ -55,15 +56,15 @@ void ConcreteAPI::initialize() {
5556
cudaEventCreate(&defaultStreamEvent); CHECK_ERR;
5657

5758
int result{0};
58-
cudaDeviceGetAttribute(&result, cudaDevAttrConcurrentManagedAccess, currentDeviceId);
59+
cudaDeviceGetAttribute(&result, cudaDevAttrConcurrentManagedAccess, getDeviceId());
5960
CHECK_ERR;
6061
allowedConcurrentManagedAccess = result != 0;
6162

6263
cudaDeviceGetStreamPriorityRange(&priorityMin, &priorityMax);
6364
CHECK_ERR;
6465

6566
int canCompressProto = 0;
66-
cuDeviceGetAttribute(&canCompressProto, CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED, currentDeviceId);
67+
cuDeviceGetAttribute(&canCompressProto, CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED, getDeviceId());
6768
canCompress = canCompressProto != 0;
6869
}
6970
else {
@@ -97,7 +98,11 @@ int ConcreteAPI::getDeviceId() {
9798
if (!status[StatusID::DeviceSelected]) {
9899
logError() << "Device has not been selected. Please, select device before requesting device Id";
99100
}
100-
return currentDeviceId;
101+
const auto myId = std::this_thread::get_id()
102+
if (deviceMap.find(myId) == deviceMap.end()) {
103+
logError() << "Thread device context not initialized. Error.";
104+
}
105+
return deviceMap[myId];
101106
}
102107

103108
unsigned ConcreteAPI::getGlobMemAlignment() {

interfaces/cuda/Copy.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void *devPtr, siz
8080
#endif
8181
}
8282
else if (allowedConcurrentManagedAccess) {
83-
location.id = currentDeviceId;
83+
location.id = getDeviceId();
8484
#if CUDART_VERSION >= 13000
8585
location.type = cudaMemLocationTypeDevice;
8686
#endif

interfaces/cuda/CudaWrappedAPI.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class ConcreteAPI : public AbstractAPI {
105105
void createCircularStreamAndEvents();
106106

107107
device::StatusT status{false};
108-
int currentDeviceId{-1};
108+
std::unordered_map<std::thread::id, std::int64_t> deviceMap;
109109
bool allowedConcurrentManagedAccess{false};
110110

111111
bool usmDefault{false};

interfaces/cuda/Memory.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void *ConcreteAPI::allocGlobMem(size_t size, bool compress) {
7676
std::memset(&prop, 0, sizeof(CUmemAllocationProp));
7777
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
7878
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
79-
prop.location.id = currentDeviceId;
79+
prop.location.id = getDeviceId();
8080
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_GENERIC;
8181

8282
devPtr = driverAllocate(size, prop);
@@ -105,7 +105,7 @@ void *ConcreteAPI::allocUnifiedMem(size_t size, bool compress, Destination hint)
105105
#endif
106106
}
107107
else if (allowedConcurrentManagedAccess) {
108-
location.id = currentDeviceId;
108+
location.id = getDeviceId();
109109
#if CUDART_VERSION >= 13000
110110
location.type = cudaMemLocationTypeDevice;
111111
#endif
@@ -198,7 +198,7 @@ std::string ConcreteAPI::getMemLeaksReport() {
198198
size_t ConcreteAPI::getMaxAvailableMem() {
199199
isFlagSet<DeviceSelected>(status);
200200
cudaDeviceProp property;
201-
cudaGetDeviceProperties(&property, currentDeviceId);
201+
cudaGetDeviceProperties(&property, getDeviceId());
202202
CHECK_ERR;
203203
return property.totalGlobalMem;
204204
}

interfaces/hip/Control.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <iomanip>
99
#include <sstream>
1010
#include <string>
11+
#include <thread>
1112
#include "hip/hip_runtime.h"
1213

1314
#ifdef PROFILING_ENABLED
@@ -26,16 +27,16 @@ ConcreteAPI::ConcreteAPI() {
2627
}
2728

2829
void ConcreteAPI::setDevice(int deviceId) {
29-
currentDeviceId = deviceId;
30-
hipSetDevice(currentDeviceId);
30+
deviceMap[std::this_thread::get_id()] = deviceId;
31+
hipSetDevice(deviceId);
3132
CHECK_ERR;
3233

3334
// Note: the following sets the initial HIP context
3435
hipFree(nullptr);
3536
CHECK_ERR;
3637

3738
hipDeviceProp_t properties{};
38-
hipGetDeviceProperties(&properties, currentDeviceId);
39+
hipGetDeviceProperties(&properties, deviceId);
3940
CHECK_ERR;
4041

4142
// NOTE: hipDeviceGetAttribute internally calls hipGetDeviceProperties; hence it doesn't make sense to use it here
@@ -102,7 +103,11 @@ int ConcreteAPI::getDeviceId() {
102103
if (!status[StatusID::DeviceSelected]) {
103104
logError() << "Device has not been selected. Please, select device before requesting device Id";
104105
}
105-
return currentDeviceId;
106+
const auto myId = std::this_thread::get_id()
107+
if (deviceMap.find(myId) == deviceMap.end()) {
108+
logError() << "Thread device context not initialized. Error.";
109+
}
110+
return deviceMap[myId];
106111
}
107112

108113
unsigned ConcreteAPI::getGlobMemAlignment() {

interfaces/hip/Copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void* devPtr, siz
8181
hipStream_t stream = (streamPtr == nullptr) ? 0 : (static_cast<hipStream_t>(streamPtr));
8282
hipMemPrefetchAsync(devPtr,
8383
count,
84-
type == Destination::CurrentDevice ? currentDeviceId : hipCpuDeviceId,
84+
type == Destination::CurrentDevice ? getDeviceId() : hipCpuDeviceId,
8585
stream);
8686
}
8787

interfaces/hip/HipWrappedAPI.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class ConcreteAPI : public AbstractAPI {
103103

104104
private:
105105
device::StatusT status{false};
106-
int currentDeviceId{-1};
106+
std::unordered_map<std::thread::id, std::int64_t> deviceMap;
107107

108108
bool usmDefault{false};
109109

interfaces/hip/Memory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void* ConcreteAPI::allocUnifiedMem(size_t size, bool compress, Destination hint)
3333
CHECK_ERR;
3434
}
3535
else {
36-
hipMemAdvise(devPtr, size, hipMemAdviseSetPreferredLocation, currentDeviceId);
36+
hipMemAdvise(devPtr, size, hipMemAdviseSetPreferredLocation, getDeviceId());
3737
CHECK_ERR;
3838
}
3939
statistics.allocatedMemBytes += size;
@@ -105,7 +105,7 @@ std::string ConcreteAPI::getMemLeaksReport() {
105105
size_t ConcreteAPI::getMaxAvailableMem() {
106106
isFlagSet<DeviceSelected>(status);
107107
hipDeviceProp_t property;
108-
hipGetDeviceProperties(&property, currentDeviceId); CHECK_ERR;
108+
hipGetDeviceProperties(&property, getDeviceId()); CHECK_ERR;
109109
return property.totalGlobalMem;
110110
}
111111

interfaces/sycl/Control.cpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <iostream>
1111
#include <string>
12+
#include <thread>
1213

1314
using namespace device;
1415

@@ -42,24 +43,21 @@ void ConcreteAPI::initDevices() {
4243
return compare(c1->queueBuffer.getDefaultQueue().get_device(), c2->queueBuffer.getDefaultQueue().get_device());
4344
});
4445

45-
this->setDevice(this->currentDeviceId);
46+
this->setDevice(0);
4647
this->deviceInitialized = true;
4748
}
4849

4950
void ConcreteAPI::setDevice(int id) {
50-
5151
if (id < 0 || id >= this->getNumDevices()) {
5252
throw std::out_of_range{"Device index out of range"};
5353
}
5454

55-
this->currentDeviceId = id;
56-
auto *next = this->availableDevices[id];
57-
this->currentStatistics = &next->statistics;
58-
this->currentQueueBuffer = &next->queueBuffer;
59-
this->currentDefaultQueue = &this->currentQueueBuffer->getDefaultQueue();
60-
this->currentMemoryToSizeMap = &next->memoryToSizeMap;
55+
if (deviceMap.empty()) {
56+
// only print the first time
57+
printer.printInfo() << "Switched to device: " << this->getDeviceName(id) << " by index " << id;
58+
}
6159

62-
printer.printInfo() << "Switched to device: " << this->getDeviceName(id) << " by index " << id;
60+
deviceMap[std::this_thread::get_id()] = id;
6361
}
6462

6563
void ConcreteAPI::initialize() {}
@@ -77,10 +75,7 @@ void ConcreteAPI::finalize() {
7775

7876
this->graphs.clear();
7977

80-
this->currentStatistics = nullptr;
81-
this->currentQueueBuffer = nullptr;
82-
this->currentDefaultQueue = nullptr;
83-
this->currentMemoryToSizeMap = nullptr;
78+
this->deviceMap.clear();
8479

8580
this->m_isFinalized = true;
8681
this->deviceInitialized = false;
@@ -92,15 +87,19 @@ int ConcreteAPI::getDeviceId() {
9287
if (!deviceInitialized) {
9388
logError() << "Device has not been selected. Please, select device before requesting device Id";
9489
}
95-
return currentDeviceId;
90+
const auto myId = std::this_thread::get_id()
91+
if (deviceMap.find(myId) == deviceMap.end()) {
92+
logError() << "Thread device context not initialized. Error.";
93+
}
94+
return deviceMap[myId];
9695
}
9796

9897
unsigned int ConcreteAPI::getGlobMemAlignment() {
99-
auto device = this->currentDefaultQueue->get_device();
98+
auto device = this->currentQueueBuffer()->getDefaultQueue()->get_device();
10099
return 128; //ToDo: find attribute; not: device.get_info<info::device::mem_base_addr_align>();
101100
}
102101

103-
void ConcreteAPI::syncDevice() { this->currentQueueBuffer->syncAllQueuesWithHost(); }
102+
void ConcreteAPI::syncDevice() { this->currentQueueBuffer()->syncAllQueuesWithHost(); }
104103

105104
std::string ConcreteAPI::getDeviceInfoAsText(int id) {
106105
if (id < 0 || id >= this->getNumDevices())
@@ -109,7 +108,7 @@ std::string ConcreteAPI::getDeviceInfoAsText(int id) {
109108
auto device = this->availableDevices[id]->queueBuffer.getDefaultQueue().get_device();
110109
return this->getDeviceInfoAsTextInternal(device);
111110
}
112-
std::string ConcreteAPI::getCurrentDeviceInfoAsText() { return this->getDeviceInfoAsText(this->currentDeviceId); }
111+
std::string ConcreteAPI::getCurrentDeviceInfoAsText() { return this->getDeviceInfoAsText(getDeviceId()); }
113112

114113
std::string ConcreteAPI::getDeviceInfoAsTextInternal(sycl::device& dev) {
115114
std::ostringstream info{};
@@ -126,7 +125,7 @@ std::string ConcreteAPI::getDeviceInfoAsTextInternal(sycl::device& dev) {
126125

127126
bool ConcreteAPI::isUnifiedMemoryDefault() {
128127
// suboptimal (i.e. we'd need to query if USM needs to be migrated or not), but there's probably nothing better for now
129-
auto device = this->availableDevices[this->currentDeviceId]->queueBuffer.getDefaultQueue().get_device();
128+
auto device = this->availableDevices[getDeviceId()]->queueBuffer.getDefaultQueue().get_device();
130129
return device.has(sycl::aspect::usm_system_allocations);
131130
}
132131

interfaces/sycl/Copy.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010
using namespace device;
1111

1212
void ConcreteAPI::copyTo(void *dst, const void *src, size_t size) {
13-
this->currentStatistics->explicitlyTransferredDataToDeviceBytes += size;
14-
this->currentDefaultQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
13+
this->currentStatistics().explicitlyTransferredDataToDeviceBytes += size;
14+
this->currentDefaultQueue().submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
1515
}
1616

1717
void ConcreteAPI::copyFrom(void *dst, const void *src, size_t size) {
18-
this->currentStatistics->explicitlyTransferredDataToHostBytes += size;
19-
this->currentDefaultQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
18+
this->currentStatistics().explicitlyTransferredDataToHostBytes += size;
19+
this->currentDefaultQueue().submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
2020
}
2121

2222
void ConcreteAPI::copyBetween(void *dst, const void *src, size_t size) {
23-
this->currentDefaultQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
23+
this->currentDefaultQueue().submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
2424
}
2525

2626
void ConcreteAPI::copy2dArrayTo(void *dst, size_t dpitch, const void *src, size_t spitch, size_t width, size_t height) {
@@ -46,8 +46,8 @@ void ConcreteAPI::copy2dArrayFrom(void *dst, size_t dpitch, const void *src, siz
4646

4747
void ConcreteAPI::copyToAsync(void *dst, const void *src, size_t count, void *streamPtr) {
4848
auto *targetQueue = (streamPtr != nullptr) ? static_cast<sycl::queue *>(streamPtr) : this->currentDefaultQueue;
49-
if (!this->currentQueueBuffer->exists(targetQueue))
50-
throw std::invalid_argument(getDeviceInfoAsText(currentDeviceId)
49+
if (!this->currentQueueBuffer().exists(targetQueue))
50+
throw std::invalid_argument(getDeviceInfoAsText(getDeviceId())
5151
.append("tried to prefetch usm on a queue that is not known to this device"));
5252

5353
targetQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, count); });
@@ -63,8 +63,8 @@ void ConcreteAPI::copyBetweenAsync(void *dst, const void *src, size_t count, voi
6363

6464
void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void *devPtr, size_t count, void *streamPtr) {
6565
auto *asQueue = (streamPtr != nullptr) ? static_cast<sycl::queue *>(streamPtr) : this->currentDefaultQueue;
66-
if (!this->currentQueueBuffer->exists(asQueue))
67-
throw std::invalid_argument(getDeviceInfoAsText(currentDeviceId)
66+
if (!this->currentQueueBuffer().exists(asQueue))
67+
throw std::invalid_argument(getDeviceInfoAsText(getDeviceId())
6868
.append("tried to prefetch usm on a queue that is not known to this device"));
6969

7070
asQueue->submit([&](sycl::handler &cgh) { cgh.prefetch(devPtr, count); });

0 commit comments

Comments
 (0)