Skip to content

Commit 65b69bf

Browse files
authored
changes for hugepage-backed host buffers for larger allocations (#1841)
1 parent 07925ec commit 65b69bf

File tree

3 files changed

+87
-18
lines changed

3 files changed

+87
-18
lines changed

src/include/alloc.h

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>   // mmap/munmap/MAP_HUGETLB for hugepage-backed host buffers
#include <unordered_map>
#include <mutex>
#include "rccl_vars.h"

// Minimum request size for the hugepage path (2 MiB — one x86-64 hugepage).
#define RCCL_HP_MIN_SIZE 2097152
2125

2226
#if CUDART_VERSION >= 11030
2327
#include <cuda.h>
@@ -31,6 +35,9 @@ constexpr size_t ncclSizeOfT() { return sizeof(T); }
3135
template<>
3236
constexpr size_t ncclSizeOfT<void>() { return 1; }
3337

38+
extern std::unordered_map<void*, size_t> hugepageAllocs;
39+
extern std::mutex hugepageAllocsMutex;
40+
3441
#if CUDART_VERSION >= 12020
3542

3643
static inline ncclResult_t ncclCuMemHostAlloc(void** ptr, CUmemGenericAllocationHandle *handlep, size_t size) {
@@ -105,43 +112,100 @@ static inline ncclResult_t ncclCuMemHostFree(void* ptr) {
105112
}
106113

107114
#endif /* CUDART_VERSION >= 12020 */
108-
109115
template <typename T>
110-
ncclResult_t ncclCudaHostCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
116+
ncclResult_t ncclCudaHostCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line, int hp_request=0 ) {
111117
ncclResult_t result = ncclSuccess;
112118
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
113119
*ptr = nullptr;
120+
size_t size = nelem * ncclSizeOfT<T>();
121+
114122
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
115123
int managed = 0;
124+
int huge=0;
116125
CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, 0));
126+
117127
if (nelem > 0) {
118128
if (managed) {
119129
#if defined(HIP_UNCACHED_MEMORY)
120-
CUDACHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem*ncclSizeOfT<T>(), hipDeviceMallocUncached), result, finish);
130+
CUDACHECKGOTO(hipExtMallocWithFlags((void**)ptr, size, hipDeviceMallocUncached), result, finish);
121131
#else
122-
CUDACHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem*ncclSizeOfT<T>(), hipDeviceMallocFinegrained), result, finish);
132+
CUDACHECKGOTO(hipExtMallocWithFlags((void**)ptr, size, hipDeviceMallocFinegrained), result, finish);
123133
#endif
124-
} else
134+
} else {
135+
if (hp_request) {
136+
if (size < RCCL_HP_MIN_SIZE) {
137+
WARN("small size : forcing back to hipHostMalloc");
125138
#if defined(HIP_HOST_UNCACHED_MEMORY)
126-
CUDACHECKGOTO(hipHostMalloc(ptr, nelem*ncclSizeOfT<T>(), cudaHostAllocMapped | hipHostMallocUncached), result, finish);
139+
CUDACHECKGOTO(hipHostMalloc(ptr, size, cudaHostAllocMapped | hipHostMallocUncached), result, finish);
127140
#else
128-
CUDACHECKGOTO(hipHostMalloc(ptr, nelem*ncclSizeOfT<T>(), cudaHostAllocMapped), result, finish);
141+
CUDACHECKGOTO(hipHostMalloc(ptr, size, cudaHostAllocMapped), result, finish);
129142
#endif
130-
memset(*ptr, 0, nelem*ncclSizeOfT<T>());
143+
memset(*ptr, 0, size);
144+
} else {
145+
// Hugepage allocation via mmap
146+
void* hostPtr = mmap(NULL, size, PROT_READ | PROT_WRITE,
147+
MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0);
148+
if (hostPtr == MAP_FAILED) {
149+
WARN("Hugepage allocation failed. Falling back to hipHostMalloc");
150+
#if defined(HIP_HOST_UNCACHED_MEMORY)
151+
CUDACHECKGOTO(hipHostMalloc(ptr, size, cudaHostAllocMapped | hipHostMallocUncached), result, finish);
152+
#else
153+
CUDACHECKGOTO(hipHostMalloc(ptr, size, cudaHostAllocMapped), result, finish);
154+
#endif
155+
memset(*ptr, 0, size);
156+
} else {
157+
memset(hostPtr, 0, size);
158+
CUDACHECKGOTO(hipHostRegister(hostPtr, size, hipHostRegisterMapped), result, finish);
159+
void* devPtr = nullptr;
160+
CUDACHECKGOTO(hipHostGetDevicePointer(&devPtr, hostPtr, 0), result, finish);
161+
*ptr = reinterpret_cast<T*>(hostPtr);
162+
INFO(NCCL_ALLOC, "Cuda Host Alloc Size done using hugepages");
163+
huge=1;
164+
std::lock_guard<std::mutex> lock(hugepageAllocsMutex);
165+
hugepageAllocs[hostPtr] = size;
166+
for (auto &kv : hugepageAllocs) INFO(NCCL_ALLOC, "updated Hugepage alloc ptr %p size %zu", kv.first, kv.second);
167+
}
168+
}
169+
} else {
170+
#if defined(HIP_HOST_UNCACHED_MEMORY)
171+
CUDACHECKGOTO(hipHostMalloc(ptr, size, cudaHostAllocMapped | hipHostMallocUncached), result, finish);
172+
#else
173+
CUDACHECKGOTO(hipHostMalloc(ptr, size, cudaHostAllocMapped), result, finish);
174+
#endif
175+
memset(*ptr, 0, size);
176+
}
177+
}
131178
}
179+
132180
finish:
133181
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
134-
if (*ptr == nullptr && nelem > 0) WARN("Failed to CUDA host alloc %ld bytes", nelem*ncclSizeOfT<T>());
135-
INFO(NCCL_ALLOC, "%s:%d Cuda Host Alloc Size %ld pointer %p", filefunc, line, nelem*ncclSizeOfT<T>(), *ptr);
182+
if (*ptr == nullptr && nelem > 0) WARN("Failed to CUDA host alloc %ld bytes", size);
183+
INFO(NCCL_ALLOC, "%s:%d Cuda Host Alloc Size %ld pointer %p hp_request %d managed %d hugepage_alloc %d", filefunc, line, size, *ptr, hp_request, managed, huge);
136184
return result;
137185
}
138186

139-
// Releases host memory obtained from ncclCudaHostCallocDebug().
//
// When hp_request is non-zero the pointer may be hugepage-backed: it is
// looked up in hugepageAllocs and, if found, unregistered and munmap'ed with
// the size recorded at allocation time. Pointers not in the table (e.g.
// hugepage requests that fell back to hipHostMalloc) are released with
// cudaFreeHost.
//
// alloc_size is informational only (logging); the tracked mmap size is
// authoritative. The table lookup is deliberately NOT gated on
// alloc_size > 0, so a caller passing 0 cannot accidentally route a
// hugepage mapping into cudaFreeHost.
static inline ncclResult_t ncclCudaHostFree(void* ptr, size_t alloc_size=0, int hp_request=0) {
  if (hp_request) {
    std::lock_guard<std::mutex> lock(hugepageAllocsMutex);
    auto it = hugepageAllocs.find(ptr);
    if (it != hugepageAllocs.end()) {
      // Tracked hugepage allocation: undo hipHostRegister, then the mmap.
      (void)hipHostUnregister(ptr);
      if (munmap(ptr, it->second) != 0) {
        WARN("munmap of hugepage buffer %p (%zu bytes) failed", ptr, it->second);
      }
      hugepageAllocs.erase(it);
      return ncclSuccess;
    }
    INFO(NCCL_ALLOC, "cudaFreeHost being done on %p, size=%zu", ptr, alloc_size);
  }
  CUDACHECK(cudaFreeHost(ptr));
  return ncclSuccess;
}
143207

144-
#define ncclCudaHostCalloc(...) ncclCudaHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
208+
#define ncclCudaHostCalloc(...) ncclCudaHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__, 0)
145209

146210
template <typename T>
147211
ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {

src/init.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", NCCL_CONFIG_UNDEF_INT);
9595

9696
struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {};
9797
static ncclResult_t commReclaim(ncclComm_t comm);
98+
std::unordered_map<void*, size_t> hugepageAllocs;
99+
std::mutex hugepageAllocsMutex;
98100

99101
#ifdef ENABLE_MSCCLPP
100102
size_t std::hash<ncclUniqueId>::operator ()(const ncclUniqueId& uniqueId) const noexcept {

src/transport/net.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ static ncclResult_t canConnect(int* ret, struct ncclComm* comm, struct ncclTopoG
179179
NCCL_PARAM(NetSharedBuffers, "NET_SHARED_BUFFERS", -2);
180180
NCCL_PARAM(NetSharedComms, "NET_SHARED_COMMS", 1);
181181

182+
RCCL_PARAM(NetHostBufferHugePageAlloc, "NET_HOST_BUFFER_HUGE_PAGE_ALLOC", 0);
182183
#if defined(HIP_CONTIGUOUS_MEMORY)
183184
RCCL_PARAM(NetContiguousMem, "NET_CONTIGUOUS_MEM", 0);
184185
#endif
@@ -602,7 +603,7 @@ static ncclResult_t sharedNetBuffersInit(struct ncclProxyState* proxyState, int
602603
}
603604
}
604605
if (!cuda && state->hostBuff == NULL) {
605-
NCCLCHECK(ncclCudaHostCalloc(&state->hostBuff, state->size));
606+
NCCLCHECK(ncclCudaHostCallocDebug(&state->hostBuff, state->size, __FILE__, __LINE__, rcclParamNetHostBufferHugePageAlloc()));
606607
}
607608
if (cpuPtr) *cpuPtr = cuda ? state->cudaBuff : state->hostBuff;
608609
if (gpuPtr) *gpuPtr = (cpuPtr && sameProcess) ? *cpuPtr : NULL;
@@ -631,7 +632,9 @@ static ncclResult_t sharedNetBuffersDestroy(struct ncclProxyState* proxyState, i
631632
}
632633
NCCLCHECK(ncclCudaFree(state->cudaBuff));
633634
}
634-
if (state->hostBuff) NCCLCHECK(ncclCudaHostFree(state->hostBuff));
635+
if (state->hostBuff) {
636+
NCCLCHECK(ncclCudaHostFree(state->hostBuff, (state->size)*(sizeof(int64_t)), rcclParamNetHostBufferHugePageAlloc()));
637+
}
635638
}
636639

637640
if (peer->send.refcount || peer->recv.refcount) return ncclSuccess;
@@ -888,7 +891,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
888891
}
889892
}
890893
if (map->sameProcess) {
891-
NCCLCHECK(ncclCudaHostCalloc(&map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size));
894+
NCCLCHECK(ncclCudaHostCallocDebug(&map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size, __FILE__, __LINE__, rcclParamNetHostBufferHugePageAlloc()));
892895
map->mems[NCCL_NET_MAP_HOSTMEM].gpuPtr = map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr;
893896
} else {
894897
NCCLCHECK(netCreateShm(proxyState, map->mems+NCCL_NET_MAP_HOSTMEM));
@@ -1090,7 +1093,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
10901093
map->mems[NCCL_NET_MAP_DEVMEM].cpuPtr = map->mems[NCCL_NET_MAP_DEVMEM].gpuPtr;
10911094
}
10921095
}
1093-
NCCLCHECK(ncclCudaHostCalloc(&map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size));
1096+
NCCLCHECK(ncclCudaHostCallocDebug(&map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, map->mems[NCCL_NET_MAP_HOSTMEM].size, __FILE__, __LINE__, rcclParamNetHostBufferHugePageAlloc()));
10941097
map->mems[NCCL_NET_MAP_HOSTMEM].gpuPtr = map->mems[NCCL_NET_MAP_HOSTMEM].cpuPtr;
10951098
if (ncclGdrCopy && map->sameProcess) {
10961099
uint64_t *cpuPtr, *gpuPtr;
@@ -1165,7 +1168,7 @@ static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct
11651168
}
11661169
struct connectMapMem* mems = resources->map.mems;
11671170
if (resources->map.sameProcess) {
1168-
NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr));
1171+
NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, (mems[NCCL_NET_MAP_HOSTMEM].size)*(sizeof(int)), rcclParamNetHostBufferHugePageAlloc()));
11691172
} else {
11701173
NCCLCHECK(ncclShmIpcClose(&mems[NCCL_NET_MAP_HOSTMEM].createDesc));
11711174
}
@@ -1209,7 +1212,7 @@ static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct
12091212
}
12101213
}
12111214
struct connectMapMem* mems = resources->map.mems;
1212-
NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr));
1215+
NCCLCHECK(ncclCudaHostFree(mems[NCCL_NET_MAP_HOSTMEM].cpuPtr, (mems[NCCL_NET_MAP_HOSTMEM].size)*(sizeof(int)), rcclParamNetHostBufferHugePageAlloc()));
12131216
NCCLCHECK(ncclCudaFree(mems[NCCL_NET_MAP_DEVMEM].cpuPtr));
12141217
if (!resources->map.sameProcess || ncclCuMemEnable()) {
12151218
// cuMem API support

0 commit comments

Comments
 (0)