Skip to content

Commit 42e19b9

Browse files
committed
Localize thread Id
1 parent 377769e commit 42e19b9

File tree

13 files changed

+131
-87
lines changed

13 files changed

+131
-87
lines changed

interfaces/cuda/Control.cu

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
#include <cuda.h>
88
#include <iostream>
99
#include <iomanip>
10+
#include <mutex>
1011
#include <sstream>
1112
#include <string>
13+
#include <thread>
1214

1315
#ifdef PROFILING_ENABLED
1416
#include <nvToolsExt.h>
@@ -26,16 +28,19 @@ ConcreteAPI::ConcreteAPI() {
2628
}
2729

2830
void ConcreteAPI::setDevice(int deviceId) {
29-
currentDeviceId = deviceId;
30-
cudaSetDevice(currentDeviceId);
31+
{
32+
std::lock_guard guard(this->apiMutex);
33+
deviceMap[std::this_thread::get_id()] = deviceId;
34+
}
35+
cudaSetDevice(deviceId);
3136
CHECK_ERR;
3237

3338
// Note: the following sets the initial CUDA context
3439
cudaFree(nullptr);
3540
CHECK_ERR;
3641

3742
int result;
38-
cudaDeviceGetAttribute(&result, cudaDevAttrDirectManagedMemAccessFromHost, currentDeviceId);
43+
cudaDeviceGetAttribute(&result, cudaDevAttrDirectManagedMemAccessFromHost, getDeviceId());
3944
usmDefault = result != 0;
4045

4146
status[StatusID::DeviceSelected] = true;
@@ -55,15 +60,15 @@ void ConcreteAPI::initialize() {
5560
cudaEventCreate(&defaultStreamEvent); CHECK_ERR;
5661

5762
int result{0};
58-
cudaDeviceGetAttribute(&result, cudaDevAttrConcurrentManagedAccess, currentDeviceId);
63+
cudaDeviceGetAttribute(&result, cudaDevAttrConcurrentManagedAccess, getDeviceId());
5964
CHECK_ERR;
6065
allowedConcurrentManagedAccess = result != 0;
6166

6267
cudaDeviceGetStreamPriorityRange(&priorityMin, &priorityMax);
6368
CHECK_ERR;
6469

6570
int canCompressProto = 0;
66-
cuDeviceGetAttribute(&canCompressProto, CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED, currentDeviceId);
71+
cuDeviceGetAttribute(&canCompressProto, CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED, getDeviceId());
6772
canCompress = canCompressProto != 0;
6873
}
6974
else {
@@ -97,7 +102,12 @@ int ConcreteAPI::getDeviceId() {
97102
if (!status[StatusID::DeviceSelected]) {
98103
logError() << "Device has not been selected. Please, select device before requesting device Id";
99104
}
100-
return currentDeviceId;
105+
const auto myId = std::this_thread::get_id();
106+
auto findResult = deviceMap.find(myId);
107+
if (findResult == deviceMap.end()) {
108+
logError() << "Thread device context not initialized. Error.";
109+
}
110+
return findResult->second;;
101111
}
102112

103113
unsigned ConcreteAPI::getGlobMemAlignment() {

interfaces/cuda/Copy.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void *devPtr, siz
8080
#endif
8181
}
8282
else if (allowedConcurrentManagedAccess) {
83-
location.id = currentDeviceId;
83+
location.id = getDeviceId();
8484
#if CUDART_VERSION >= 13000
8585
location.type = cudaMemLocationTypeDevice;
8686
#endif

interfaces/cuda/CudaWrappedAPI.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <string>
1010
#include <unordered_map>
1111
#include <vector>
12+
#include <thread>
1213
#include <unordered_set>
1314
#include <cassert>
1415
#include <cstdint>
@@ -105,7 +106,7 @@ class ConcreteAPI : public AbstractAPI {
105106
void createCircularStreamAndEvents();
106107

107108
device::StatusT status{false};
108-
int currentDeviceId{-1};
109+
std::unordered_map<std::thread::id, int> deviceMap;
109110
bool allowedConcurrentManagedAccess{false};
110111

111112
bool usmDefault{false};

interfaces/cuda/Memory.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void *ConcreteAPI::allocGlobMem(size_t size, bool compress) {
7676
std::memset(&prop, 0, sizeof(CUmemAllocationProp));
7777
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
7878
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
79-
prop.location.id = currentDeviceId;
79+
prop.location.id = getDeviceId();
8080
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_GENERIC;
8181

8282
devPtr = driverAllocate(size, prop);
@@ -105,7 +105,7 @@ void *ConcreteAPI::allocUnifiedMem(size_t size, bool compress, Destination hint)
105105
#endif
106106
}
107107
else if (allowedConcurrentManagedAccess) {
108-
location.id = currentDeviceId;
108+
location.id = getDeviceId();
109109
#if CUDART_VERSION >= 13000
110110
location.type = cudaMemLocationTypeDevice;
111111
#endif
@@ -198,7 +198,7 @@ std::string ConcreteAPI::getMemLeaksReport() {
198198
size_t ConcreteAPI::getMaxAvailableMem() {
199199
isFlagSet<DeviceSelected>(status);
200200
cudaDeviceProp property;
201-
cudaGetDeviceProperties(&property, currentDeviceId);
201+
cudaGetDeviceProperties(&property, getDeviceId());
202202
CHECK_ERR;
203203
return property.totalGlobalMem;
204204
}

interfaces/hip/Control.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
#include "utils/env.h"
77
#include <iostream>
88
#include <iomanip>
9+
#include <mutex>
910
#include <sstream>
1011
#include <string>
12+
#include <thread>
1113
#include "hip/hip_runtime.h"
1214

1315
#ifdef PROFILING_ENABLED
@@ -26,16 +28,19 @@ ConcreteAPI::ConcreteAPI() {
2628
}
2729

2830
void ConcreteAPI::setDevice(int deviceId) {
29-
currentDeviceId = deviceId;
30-
hipSetDevice(currentDeviceId);
31+
{
32+
std::lock_guard guard(this->apiMutex);
33+
deviceMap[std::this_thread::get_id()] = deviceId;
34+
}
35+
hipSetDevice(deviceId);
3136
CHECK_ERR;
3237

3338
// Note: the following sets the initial HIP context
3439
hipFree(nullptr);
3540
CHECK_ERR;
3641

3742
hipDeviceProp_t properties{};
38-
hipGetDeviceProperties(&properties, currentDeviceId);
43+
hipGetDeviceProperties(&properties, deviceId);
3944
CHECK_ERR;
4045

4146
// NOTE: hipDeviceGetAttribute internally calls hipGetDeviceProperties; hence it doesn't make sense to use it here
@@ -102,7 +107,12 @@ int ConcreteAPI::getDeviceId() {
102107
if (!status[StatusID::DeviceSelected]) {
103108
logError() << "Device has not been selected. Please, select device before requesting device Id";
104109
}
105-
return currentDeviceId;
110+
const auto myId = std::this_thread::get_id();
111+
auto findResult = deviceMap.find(myId);
112+
if (findResult == deviceMap.end()) {
113+
logError() << "Thread device context not initialized. Error.";
114+
}
115+
return findResult->second;;
106116
}
107117

108118
unsigned ConcreteAPI::getGlobMemAlignment() {

interfaces/hip/Copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void* devPtr, siz
8181
hipStream_t stream = (streamPtr == nullptr) ? 0 : (static_cast<hipStream_t>(streamPtr));
8282
hipMemPrefetchAsync(devPtr,
8383
count,
84-
type == Destination::CurrentDevice ? currentDeviceId : hipCpuDeviceId,
84+
type == Destination::CurrentDevice ? getDeviceId() : hipCpuDeviceId,
8585
stream);
8686
}
8787

interfaces/hip/HipWrappedAPI.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <string>
1515
#include <unordered_map>
1616
#include <vector>
17+
#include <thread>
1718
#include <unordered_set>
1819
#include <cassert>
1920

@@ -103,7 +104,7 @@ class ConcreteAPI : public AbstractAPI {
103104

104105
private:
105106
device::StatusT status{false};
106-
int currentDeviceId{-1};
107+
std::unordered_map<std::thread::id, int> deviceMap;
107108

108109
bool usmDefault{false};
109110

interfaces/hip/Memory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void* ConcreteAPI::allocUnifiedMem(size_t size, bool compress, Destination hint)
3333
CHECK_ERR;
3434
}
3535
else {
36-
hipMemAdvise(devPtr, size, hipMemAdviseSetPreferredLocation, currentDeviceId);
36+
hipMemAdvise(devPtr, size, hipMemAdviseSetPreferredLocation, getDeviceId());
3737
CHECK_ERR;
3838
}
3939
statistics.allocatedMemBytes += size;
@@ -105,7 +105,7 @@ std::string ConcreteAPI::getMemLeaksReport() {
105105
size_t ConcreteAPI::getMaxAvailableMem() {
106106
isFlagSet<DeviceSelected>(status);
107107
hipDeviceProp_t property;
108-
hipGetDeviceProperties(&property, currentDeviceId); CHECK_ERR;
108+
hipGetDeviceProperties(&property, getDeviceId()); CHECK_ERR;
109109
return property.totalGlobalMem;
110110
}
111111

interfaces/sycl/Control.cpp

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <iostream>
1111
#include <string>
12+
#include <thread>
1213

1314
using namespace device;
1415

@@ -42,24 +43,25 @@ void ConcreteAPI::initDevices() {
4243
return compare(c1->queueBuffer.getDefaultQueue().get_device(), c2->queueBuffer.getDefaultQueue().get_device());
4344
});
4445

45-
this->setDevice(this->currentDeviceId);
46+
this->setDevice(0);
4647
this->deviceInitialized = true;
4748
}
4849

4950
void ConcreteAPI::setDevice(int id) {
51+
{
52+
std::lock_guard guard(this->apiMutex);
5053

51-
if (id < 0 || id >= this->getNumDevices()) {
52-
throw std::out_of_range{"Device index out of range"};
53-
}
54+
if (id < 0 || id >= this->getNumDevices()) {
55+
throw std::out_of_range{"Device index out of range"};
56+
}
5457

55-
this->currentDeviceId = id;
56-
auto *next = this->availableDevices[id];
57-
this->currentStatistics = &next->statistics;
58-
this->currentQueueBuffer = &next->queueBuffer;
59-
this->currentDefaultQueue = &this->currentQueueBuffer->getDefaultQueue();
60-
this->currentMemoryToSizeMap = &next->memoryToSizeMap;
58+
if (deviceMap.empty()) {
59+
// only print the first time
60+
printer.printInfo() << "Switched to device: " << this->getDeviceName(id) << " by index " << id;
61+
}
6162

62-
printer.printInfo() << "Switched to device: " << this->getDeviceName(id) << " by index " << id;
63+
deviceMap[std::this_thread::get_id()] = id;
64+
}
6365
}
6466

6567
/// Intentionally empty: this back-end performs no work in initialize().
void ConcreteAPI::initialize() {}
@@ -77,10 +79,7 @@ void ConcreteAPI::finalize() {
7779

7880
this->graphs.clear();
7981

80-
this->currentStatistics = nullptr;
81-
this->currentQueueBuffer = nullptr;
82-
this->currentDefaultQueue = nullptr;
83-
this->currentMemoryToSizeMap = nullptr;
82+
this->deviceMap.clear();
8483

8584
this->m_isFinalized = true;
8685
this->deviceInitialized = false;
@@ -92,15 +91,20 @@ int ConcreteAPI::getDeviceId() {
9291
if (!deviceInitialized) {
9392
logError() << "Device has not been selected. Please, select device before requesting device Id";
9493
}
95-
return currentDeviceId;
94+
const auto myId = std::this_thread::get_id();
95+
auto findResult = deviceMap.find(myId);
96+
if (findResult == deviceMap.end()) {
97+
logError() << "Thread device context not initialized. Error.";
98+
}
99+
return findResult->second;;
96100
}
97101

98102
unsigned int ConcreteAPI::getGlobMemAlignment() {
99-
auto device = this->currentDefaultQueue->get_device();
103+
auto device = this->currentQueueBuffer()->getDefaultQueue()->get_device();
100104
return 128; //ToDo: find attribute; not: device.get_info<info::device::mem_base_addr_align>();
101105
}
102106

103-
void ConcreteAPI::syncDevice() { this->currentQueueBuffer->syncAllQueuesWithHost(); }
107+
void ConcreteAPI::syncDevice() { this->currentQueueBuffer()->syncAllQueuesWithHost(); }
104108

105109
std::string ConcreteAPI::getDeviceInfoAsText(int id) {
106110
if (id < 0 || id >= this->getNumDevices())
@@ -109,7 +113,7 @@ std::string ConcreteAPI::getDeviceInfoAsText(int id) {
109113
auto device = this->availableDevices[id]->queueBuffer.getDefaultQueue().get_device();
110114
return this->getDeviceInfoAsTextInternal(device);
111115
}
112-
std::string ConcreteAPI::getCurrentDeviceInfoAsText() { return this->getDeviceInfoAsText(this->currentDeviceId); }
116+
std::string ConcreteAPI::getCurrentDeviceInfoAsText() { return this->getDeviceInfoAsText(getDeviceId()); }
113117

114118
std::string ConcreteAPI::getDeviceInfoAsTextInternal(sycl::device& dev) {
115119
std::ostringstream info{};
@@ -126,7 +130,7 @@ std::string ConcreteAPI::getDeviceInfoAsTextInternal(sycl::device& dev) {
126130

127131
bool ConcreteAPI::isUnifiedMemoryDefault() {
128132
// suboptimal (i.e. we'd need to query if USM needs to be migrated or not), but there's probably nothing better for now
129-
auto device = this->availableDevices[this->currentDeviceId]->queueBuffer.getDefaultQueue().get_device();
133+
auto device = this->availableDevices[getDeviceId()]->queueBuffer.getDefaultQueue().get_device();
130134
return device.has(sycl::aspect::usm_system_allocations);
131135
}
132136

interfaces/sycl/Copy.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010
using namespace device;
1111

1212
void ConcreteAPI::copyTo(void *dst, const void *src, size_t size) {
13-
this->currentStatistics->explicitlyTransferredDataToDeviceBytes += size;
14-
this->currentDefaultQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
13+
this->currentStatistics().explicitlyTransferredDataToDeviceBytes += size;
14+
this->currentDefaultQueue().submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
1515
}
1616

1717
void ConcreteAPI::copyFrom(void *dst, const void *src, size_t size) {
18-
this->currentStatistics->explicitlyTransferredDataToHostBytes += size;
19-
this->currentDefaultQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
18+
this->currentStatistics().explicitlyTransferredDataToHostBytes += size;
19+
this->currentDefaultQueue().submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
2020
}
2121

2222
void ConcreteAPI::copyBetween(void *dst, const void *src, size_t size) {
23-
this->currentDefaultQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
23+
this->currentDefaultQueue().submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, size); }).wait_and_throw();
2424
}
2525

2626
void ConcreteAPI::copy2dArrayTo(void *dst, size_t dpitch, const void *src, size_t spitch, size_t width, size_t height) {
@@ -46,8 +46,8 @@ void ConcreteAPI::copy2dArrayFrom(void *dst, size_t dpitch, const void *src, siz
4646

4747
void ConcreteAPI::copyToAsync(void *dst, const void *src, size_t count, void *streamPtr) {
4848
auto *targetQueue = (streamPtr != nullptr) ? static_cast<sycl::queue *>(streamPtr) : this->currentDefaultQueue;
49-
if (!this->currentQueueBuffer->exists(targetQueue))
50-
throw std::invalid_argument(getDeviceInfoAsText(currentDeviceId)
49+
if (!this->currentQueueBuffer().exists(targetQueue))
50+
throw std::invalid_argument(getDeviceInfoAsText(getDeviceId())
5151
.append("tried to prefetch usm on a queue that is not known to this device"));
5252

5353
targetQueue->submit([&](sycl::handler &cgh) { cgh.memcpy(dst, src, count); });
@@ -63,8 +63,8 @@ void ConcreteAPI::copyBetweenAsync(void *dst, const void *src, size_t count, voi
6363

6464
void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void *devPtr, size_t count, void *streamPtr) {
6565
auto *asQueue = (streamPtr != nullptr) ? static_cast<sycl::queue *>(streamPtr) : this->currentDefaultQueue;
66-
if (!this->currentQueueBuffer->exists(asQueue))
67-
throw std::invalid_argument(getDeviceInfoAsText(currentDeviceId)
66+
if (!this->currentQueueBuffer().exists(asQueue))
67+
throw std::invalid_argument(getDeviceInfoAsText(getDeviceId())
6868
.append("tried to prefetch usm on a queue that is not known to this device"));
6969

7070
asQueue->submit([&](sycl::handler &cgh) { cgh.prefetch(devPtr, count); });

0 commit comments

Comments
 (0)