Skip to content

Commit 6d73b9c

Browse files
authored
Merge pull request #53 from SeisSol/davschneller/memory-fixes
Localize thread Id
2 parents 2ddd358 + e405805 commit 6d73b9c

File tree

17 files changed

+176
-175
lines changed

17 files changed

+176
-175
lines changed

interfaces/common/Common.h

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -50,49 +50,5 @@ constexpr auto mapPercentage(int minval, int maxval, double value) {
5050
return std::max(std::min(static_cast<int>(std::floor(transformed)), maxval), minval);
5151
}
5252

53-
class InfoPrinter {
54-
public:
55-
struct InfoPrinterLine {
56-
std::shared_ptr<std::ostringstream> stream;
57-
InfoPrinter& printer;
58-
59-
InfoPrinterLine(InfoPrinter& printer) : stream(std::make_shared<std::ostringstream>()), printer(printer) {}
60-
61-
template<typename T>
62-
InfoPrinterLine& operator <<(const T& data) {
63-
*stream << data;
64-
return *this;
65-
}
66-
67-
~InfoPrinterLine() {
68-
if (printer.rank < 0) {
69-
printer.stringCache.emplace_back(stream->str());
70-
}
71-
else {
72-
logInfo() << stream->str().c_str();
73-
}
74-
stream = nullptr;
75-
}
76-
};
77-
78-
InfoPrinterLine printInfo() {
79-
return InfoPrinterLine(*this);
80-
}
81-
82-
void setRank(int rank) {
83-
this->rank = rank;
84-
85-
for (const auto& string : stringCache) {
86-
logInfo() << string;
87-
}
88-
89-
stringCache.resize(0);
90-
}
91-
92-
private:
93-
int rank{-1};
94-
std::vector<std::string> stringCache;
95-
};
96-
9753

9854
#endif // SEISSOLDEVICE_INTERFACES_COMMON_COMMON_H_

interfaces/cuda/Control.cu

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
#include <cuda.h>
88
#include <iostream>
99
#include <iomanip>
10+
#include <mutex>
1011
#include <sstream>
1112
#include <string>
13+
#include <thread>
1214

1315
#ifdef PROFILING_ENABLED
1416
#include <nvToolsExt.h>
@@ -17,6 +19,16 @@
1719
#include "CudaWrappedAPI.h"
1820
#include "Internals.h"
1921

22+
namespace {
23+
// `static` is a bit out of place here; but we treat the whole class as an effective singleton anyways
24+
25+
#ifdef DEVICE_CONTEXT_GLOBAL
26+
int currentDeviceId = 0;
27+
#else
28+
thread_local int currentDeviceId = 0;
29+
#endif
30+
}
31+
2032
using namespace device;
2133

2234
ConcreteAPI::ConcreteAPI() {
@@ -26,18 +38,16 @@ ConcreteAPI::ConcreteAPI() {
2638
}
2739

2840
void ConcreteAPI::setDevice(int deviceId) {
41+
2942
currentDeviceId = deviceId;
30-
cudaSetDevice(currentDeviceId);
43+
44+
cudaSetDevice(deviceId);
3145
CHECK_ERR;
3246

3347
// Note: the following sets the initial CUDA context
3448
cudaFree(nullptr);
3549
CHECK_ERR;
3650

37-
int result;
38-
cudaDeviceGetAttribute(&result, cudaDevAttrDirectManagedMemAccessFromHost, currentDeviceId);
39-
usmDefault = result != 0;
40-
4151
status[StatusID::DeviceSelected] = true;
4252
}
4353

@@ -52,18 +62,20 @@ void ConcreteAPI::initialize() {
5262
if (!status[StatusID::InterfaceInitialized]) {
5363
status[StatusID::InterfaceInitialized] = true;
5464
cudaStreamCreateWithFlags(&defaultStream, cudaStreamNonBlocking); CHECK_ERR;
55-
cudaEventCreate(&defaultStreamEvent); CHECK_ERR;
5665

5766
int result{0};
58-
cudaDeviceGetAttribute(&result, cudaDevAttrConcurrentManagedAccess, currentDeviceId);
67+
cudaDeviceGetAttribute(&result, cudaDevAttrConcurrentManagedAccess, getDeviceId());
5968
CHECK_ERR;
6069
allowedConcurrentManagedAccess = result != 0;
6170

71+
cudaDeviceGetAttribute(&result, cudaDevAttrDirectManagedMemAccessFromHost, getDeviceId());
72+
usmDefault = result != 0;
73+
6274
cudaDeviceGetStreamPriorityRange(&priorityMin, &priorityMax);
6375
CHECK_ERR;
6476

6577
int canCompressProto = 0;
66-
cuDeviceGetAttribute(&canCompressProto, CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED, currentDeviceId);
78+
cuDeviceGetAttribute(&canCompressProto, CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED, getDeviceId());
6779
canCompress = canCompressProto != 0;
6880
}
6981
else {
@@ -74,9 +86,8 @@ void ConcreteAPI::initialize() {
7486
void ConcreteAPI::finalize() {
7587
if (status[StatusID::InterfaceInitialized]) {
7688
cudaStreamDestroy(defaultStream); CHECK_ERR;
77-
cudaEventDestroy(defaultStreamEvent); CHECK_ERR;
7889
if (!genericStreams.empty()) {
79-
printer.printInfo() << "DEVICE::WARNING:" << genericStreams.size()
90+
logInfo() << "DEVICE::WARNING:" << genericStreams.size()
8091
<< "device generic stream(s) were not deleted.";
8192
for (auto stream : genericStreams) {
8293
cudaStreamDestroy(stream); CHECK_ERR;
@@ -186,6 +197,6 @@ void ConcreteAPI::popLastProfilingMark() {
186197
}
187198

188199
void ConcreteAPI::setupPrinting(int rank) {
189-
printer.setRank(rank);
200+
190201
}
191202

interfaces/cuda/Copy.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void *devPtr, siz
8080
#endif
8181
}
8282
else if (allowedConcurrentManagedAccess) {
83-
location.id = currentDeviceId;
83+
location.id = getDeviceId();
8484
#if CUDART_VERSION >= 13000
8585
location.type = cudaMemLocationTypeDevice;
8686
#endif

interfaces/cuda/CudaWrappedAPI.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <string>
1010
#include <unordered_map>
1111
#include <vector>
12+
#include <thread>
1213
#include <unordered_set>
1314
#include <cassert>
1415
#include <cstdint>
@@ -105,14 +106,15 @@ class ConcreteAPI : public AbstractAPI {
105106
void createCircularStreamAndEvents();
106107

107108
device::StatusT status{false};
108-
int currentDeviceId{-1};
109+
110+
std::vector<cudaDeviceProp> properties;
111+
109112
bool allowedConcurrentManagedAccess{false};
110113

111114
bool usmDefault{false};
112115
bool canCompress{false};
113116

114117
cudaStream_t defaultStream{nullptr};
115-
cudaEvent_t defaultStreamEvent{};
116118

117119
std::unordered_set<cudaStream_t> genericStreams{};
118120

@@ -128,7 +130,6 @@ class ConcreteAPI : public AbstractAPI {
128130
std::unordered_map<void *, size_t> memToSizeMap{{nullptr, 0}};
129131

130132
int priorityMin, priorityMax;
131-
InfoPrinter printer;
132133

133134
std::unordered_map<void *, void *> allocationProperties;
134135
};

interfaces/cuda/Memory.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void *ConcreteAPI::allocGlobMem(size_t size, bool compress) {
7676
std::memset(&prop, 0, sizeof(CUmemAllocationProp));
7777
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
7878
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
79-
prop.location.id = currentDeviceId;
79+
prop.location.id = getDeviceId();
8080
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_GENERIC;
8181

8282
devPtr = driverAllocate(size, prop);
@@ -105,7 +105,7 @@ void *ConcreteAPI::allocUnifiedMem(size_t size, bool compress, Destination hint)
105105
#endif
106106
}
107107
else if (allowedConcurrentManagedAccess) {
108-
location.id = currentDeviceId;
108+
location.id = getDeviceId();
109109
#if CUDART_VERSION >= 13000
110110
location.type = cudaMemLocationTypeDevice;
111111
#endif
@@ -198,7 +198,7 @@ std::string ConcreteAPI::getMemLeaksReport() {
198198
size_t ConcreteAPI::getMaxAvailableMem() {
199199
isFlagSet<DeviceSelected>(status);
200200
cudaDeviceProp property;
201-
cudaGetDeviceProperties(&property, currentDeviceId);
201+
cudaGetDeviceProperties(&property, getDeviceId());
202202
CHECK_ERR;
203203
return property.totalGlobalMem;
204204
}

interfaces/hip/Control.cpp

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
#include "utils/env.h"
77
#include <iostream>
88
#include <iomanip>
9+
#include <mutex>
910
#include <sstream>
1011
#include <string>
12+
#include <thread>
1113
#include "hip/hip_runtime.h"
1214

1315
#ifdef PROFILING_ENABLED
@@ -19,40 +21,31 @@
1921

2022
using namespace device;
2123

24+
namespace {
25+
#ifdef DEVICE_CONTEXT_GLOBAL
26+
int currentDeviceId = 0;
27+
#else
28+
thread_local int currentDeviceId = 0;
29+
#endif
30+
}
31+
2232
ConcreteAPI::ConcreteAPI() {
2333
hipInit(0);
2434
CHECK_ERR;
2535
status[StatusID::DriverApiInitialized] = true;
2636
}
2737

2838
void ConcreteAPI::setDevice(int deviceId) {
39+
2940
currentDeviceId = deviceId;
30-
hipSetDevice(currentDeviceId);
41+
42+
hipSetDevice(deviceId);
3143
CHECK_ERR;
3244

3345
// Note: the following sets the initial HIP context
3446
hipFree(nullptr);
3547
CHECK_ERR;
3648

37-
hipDeviceProp_t properties{};
38-
hipGetDeviceProperties(&properties, currentDeviceId);
39-
CHECK_ERR;
40-
41-
// NOTE: hipDeviceGetAttribute internally calls hipGetDeviceProperties; hence it doesn't make sense to use it here
42-
43-
if constexpr (HIP_VERSION >= 60200000) {
44-
// cf. https://rocm.docs.amd.com/en/docs-6.2.0/about/release-notes.html
45-
// (before 6.2.0, the flag hipDeviceAttributePageableMemoryAccessUsesHostPageTables had effectively the same effect)
46-
// (cf. https://github.com/ROCm/clr/commit/7d5b4a8f7a7d34f008d65277f8aae4c98a6da375#diff-596cd550f7fdef76b39f1b7b179b20128313dd9cc9ec662b2eae562efa2b7f33L405 )
47-
usmDefault = properties.integrated != 0;
48-
}
49-
else {
50-
usmDefault = properties.directManagedMemAccessFromHost != 0 && properties.pageableMemoryAccessUsesHostPageTables != 0;
51-
}
52-
53-
hipDeviceGetStreamPriorityRange(&priorityMin, &priorityMax);
54-
CHECK_ERR;
55-
5649
status[StatusID::DeviceSelected] = true;
5750
}
5851

@@ -68,7 +61,25 @@ void ConcreteAPI::initialize() {
6861
if (!status[StatusID::InterfaceInitialized]) {
6962
status[StatusID::InterfaceInitialized] = true;
7063
hipStreamCreateWithFlags(&defaultStream, hipStreamNonBlocking); CHECK_ERR;
71-
hipEventCreate(&defaultStreamEvent); CHECK_ERR;
64+
65+
hipDeviceProp_t properties{};
66+
hipGetDeviceProperties(&properties, getDeviceId());
67+
CHECK_ERR;
68+
69+
// NOTE: hipDeviceGetAttribute internally calls hipGetDeviceProperties; hence it doesn't make sense to use it here
70+
71+
if constexpr (HIP_VERSION >= 60200000) {
72+
// cf. https://rocm.docs.amd.com/en/docs-6.2.0/about/release-notes.html
73+
// (before 6.2.0, the flag hipDeviceAttributePageableMemoryAccessUsesHostPageTables had effectively the same effect)
74+
// (cf. https://github.com/ROCm/clr/commit/7d5b4a8f7a7d34f008d65277f8aae4c98a6da375#diff-596cd550f7fdef76b39f1b7b179b20128313dd9cc9ec662b2eae562efa2b7f33L405 )
75+
usmDefault = properties.integrated != 0;
76+
}
77+
else {
78+
usmDefault = properties.directManagedMemAccessFromHost != 0 && properties.pageableMemoryAccessUsesHostPageTables != 0;
79+
}
80+
81+
hipDeviceGetStreamPriorityRange(&priorityMin, &priorityMax);
82+
CHECK_ERR;
7283
}
7384
else {
7485
logWarning() << "Device Interface has already been initialized";
@@ -78,9 +89,8 @@ void ConcreteAPI::initialize() {
7889
void ConcreteAPI::finalize() {
7990
if (status[StatusID::InterfaceInitialized]) {
8091
hipStreamDestroy(defaultStream); CHECK_ERR;
81-
hipEventDestroy(defaultStreamEvent); CHECK_ERR;
8292
if (!genericStreams.empty()) {
83-
printer.printInfo() << "DEVICE::WARNING:" << genericStreams.size()
93+
logInfo() << "DEVICE::WARNING:" << genericStreams.size()
8494
<< "device generic stream(s) were not deleted.";
8595
for (auto stream : genericStreams) {
8696
hipStreamDestroy(stream); CHECK_ERR;
@@ -189,6 +199,6 @@ void ConcreteAPI::popLastProfilingMark() {
189199
}
190200

191201
void ConcreteAPI::setupPrinting(int rank) {
192-
printer.setRank(rank);
202+
193203
}
194204

interfaces/hip/Copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void ConcreteAPI::prefetchUnifiedMemTo(Destination type, const void* devPtr, siz
8181
hipStream_t stream = (streamPtr == nullptr) ? 0 : (static_cast<hipStream_t>(streamPtr));
8282
hipMemPrefetchAsync(devPtr,
8383
count,
84-
type == Destination::CurrentDevice ? currentDeviceId : hipCpuDeviceId,
84+
type == Destination::CurrentDevice ? getDeviceId() : hipCpuDeviceId,
8585
stream);
8686
}
8787

interfaces/hip/HipWrappedAPI.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <string>
1515
#include <unordered_map>
1616
#include <vector>
17+
#include <thread>
1718
#include <unordered_set>
1819
#include <cassert>
1920

@@ -103,12 +104,10 @@ class ConcreteAPI : public AbstractAPI {
103104

104105
private:
105106
device::StatusT status{false};
106-
int currentDeviceId{-1};
107107

108108
bool usmDefault{false};
109109

110110
hipStream_t defaultStream{nullptr};
111-
hipEvent_t defaultStreamEvent{};
112111

113112
std::unordered_set<hipStream_t> genericStreams{};
114113

@@ -124,7 +123,6 @@ class ConcreteAPI : public AbstractAPI {
124123
std::unordered_map<void *, size_t> memToSizeMap{{nullptr, 0}};
125124

126125
int priorityMin, priorityMax;
127-
InfoPrinter printer;
128126
};
129127
} // namespace device
130128

interfaces/hip/Memory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void* ConcreteAPI::allocUnifiedMem(size_t size, bool compress, Destination hint)
3333
CHECK_ERR;
3434
}
3535
else {
36-
hipMemAdvise(devPtr, size, hipMemAdviseSetPreferredLocation, currentDeviceId);
36+
hipMemAdvise(devPtr, size, hipMemAdviseSetPreferredLocation, getDeviceId());
3737
CHECK_ERR;
3838
}
3939
statistics.allocatedMemBytes += size;
@@ -105,7 +105,7 @@ std::string ConcreteAPI::getMemLeaksReport() {
105105
size_t ConcreteAPI::getMaxAvailableMem() {
106106
isFlagSet<DeviceSelected>(status);
107107
hipDeviceProp_t property;
108-
hipGetDeviceProperties(&property, currentDeviceId); CHECK_ERR;
108+
hipGetDeviceProperties(&property, getDeviceId()); CHECK_ERR;
109109
return property.totalGlobalMem;
110110
}
111111

0 commit comments

Comments
 (0)