Skip to content

Commit 250e476

Browse files
committed
UCT/CUDA_IPC: Support VMM with multiple memory allocations
Handle CUDA virtual memory management (VMM) allocations that span multiple cuMemCreate chunks by discovering all chunks, exporting their fabric handles into a GPU metadata buffer, and sharing that buffer's fabric handle via the rkey. On the receiver, the metadata is fetched and a persistent contiguous virtual address (VA) mapping is created by importing each chunk individually. Address translation for put/get operations uses this mapping directly.
1 parent 34243b1 commit 250e476

File tree

8 files changed

+852
-30
lines changed

8 files changed

+852
-30
lines changed

src/uct/cuda/cuda_ipc/cuda_ipc.inl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,22 +83,36 @@ uct_cuda_ipc_check_and_pop_ctx(int is_ctx_pushed)
8383
}
8484

8585
static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_ipc_get_remote_address(
86-
uct_cuda_ipc_extended_rkey_t *rkey, uint64_t raddr, CUdevice cu_dev,
86+
uct_cuda_ipc_unpacked_rkey_t *rkey, uint64_t raddr, CUdevice cu_dev,
8787
void **laddr_p, void **base_addr_p)
8888
{
8989
ucs_status_t status;
9090
ptrdiff_t offset;
9191
void *mapped_addr;
9292

93-
status = uct_cuda_ipc_map_memhandle(rkey, cu_dev, &mapped_addr,
93+
#if HAVE_CUDA_FABRIC
94+
if (rkey->super.super.ph.handle_type ==
95+
UCT_CUDA_IPC_KEY_HANDLE_TYPE_VMM_MULTI) {
96+
offset = raddr - rkey->super.super.d_bptr;
97+
ucs_assertv(offset <= rkey->super.super.b_len,
98+
"offset:%ld b_len:%lu", offset, rkey->super.super.b_len);
99+
*laddr_p = (void*)(rkey->contig_va + offset);
100+
if (base_addr_p != NULL) {
101+
*base_addr_p = NULL;
102+
}
103+
return UCS_OK;
104+
}
105+
#endif
106+
107+
status = uct_cuda_ipc_map_memhandle(&rkey->super, cu_dev, &mapped_addr,
94108
UCS_LOG_LEVEL_ERROR);
95109
if (ucs_unlikely(status != UCS_OK)) {
96110
return status;
97111
}
98112

99-
offset = UCS_PTR_BYTE_DIFF(rkey->super.d_bptr, raddr);
100-
ucs_assertv(offset <= rkey->super.b_len, "offset:%ld b_len:%lu", offset,
101-
rkey->super.b_len);
113+
offset = UCS_PTR_BYTE_DIFF(rkey->super.super.d_bptr, raddr);
114+
ucs_assertv(offset <= rkey->super.super.b_len, "offset:%ld b_len:%lu",
115+
offset, rkey->super.super.b_len);
102116
*laddr_p = UCS_PTR_BYTE_OFFSET(mapped_addr, offset);
103117
if (base_addr_p != NULL) {
104118
*base_addr_p = mapped_addr;

src/uct/cuda/cuda_ipc/cuda_ipc_cache.c

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,14 +297,6 @@ uct_cuda_ipc_open_memhandle_legacy(CUipcMemHandle memh, CUdevice cu_dev,
297297
}
298298

299299
#if HAVE_CUDA_FABRIC
300-
static void
301-
uct_cuda_ipc_init_access_desc(CUmemAccessDesc *access_desc, CUdevice cu_dev)
302-
{
303-
access_desc->location.type = CU_MEM_LOCATION_TYPE_DEVICE;
304-
access_desc->flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
305-
access_desc->location.id = cu_dev;
306-
}
307-
308300
static ucs_status_t
309301
uct_cuda_ipc_open_memhandle_vmm(const uct_cuda_ipc_rkey_t *key, CUdevice cu_dev,
310302
CUdeviceptr *mapped_addr,

src/uct/cuda/cuda_ipc/cuda_ipc_ep.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,8 @@ uct_cuda_ipc_post_cuda_async_copy(uct_ep_h tl_ep, uint64_t remote_addr,
120120
return status;
121121
}
122122

123-
status = uct_cuda_ipc_get_remote_address(&key->super, remote_addr,
124-
cuda_device, &mapped_rem_addr,
125-
&mapped_addr);
123+
status = uct_cuda_ipc_get_remote_address(key, remote_addr, cuda_device,
124+
&mapped_rem_addr, &mapped_addr);
126125
if (ucs_unlikely(status != UCS_OK)) {
127126
goto out;
128127
}

src/uct/cuda/cuda_ipc/cuda_ipc_iface.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,10 @@ static void uct_cuda_ipc_complete_event(uct_iface_h tl_iface,
311311
uct_cuda_ipc_event_desc_t);
312312
ucs_status_t status;
313313

314+
if (cuda_ipc_event->mapped_addr == NULL) {
315+
return; /* VMM_MULTI persistent mapping: cleanup at rkey_release */
316+
}
317+
314318
status = uct_cuda_ipc_unmap_memhandle(cuda_ipc_event->pid,
315319
cuda_ipc_event->pid_ns,
316320
cuda_ipc_event->d_bptr,

0 commit comments

Comments
 (0)