Skip to content

Commit ccaf7e2

Browse files
galletas1712 and claude committed
Refactor GMS client: enum lock types, simplified CUDA VMM, dataclass methods
Major changes: - Add RequestedLockType and GrantedLockType enums to replace string literals - Move client-only CUDA VMM utilities to client/cuda_vmm_utils.py - Simplify allocator.cpp to thin Python-calling shim (412 -> 104 lines) - Reduce tensor_from_pointer.cpp by removing contiguous version (77 -> 33 lines) - Consolidate TensorMeta/ParsedTensorMeta into single TensorMetadata class - Move standalone functions into dataclass methods (TensorMetadata, TensorIPCInfo, etc.) - Remove unused nbytes/span_bytes fields from tensor metadata Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 7c9c2ac commit ccaf7e2

File tree

15 files changed

+708
-1274
lines changed

15 files changed

+708
-1274
lines changed

lib/gpu_memory_service/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# Primary client exports
2727
from gpu_memory_service.client.memory_manager import (
2828
GMSClientMemoryManager,
29-
StaleWeightsError,
29+
StaleMemoryLayoutError,
3030
)
3131

3232
# PyTorch integration (lifecycle management)
@@ -38,7 +38,7 @@
3838
__all__ = [
3939
# Client
4040
"GMSClientMemoryManager",
41-
"StaleWeightsError",
41+
"StaleMemoryLayoutError",
4242
# Lifecycle
4343
"get_or_create_allocator",
4444
"get_allocator",

lib/gpu_memory_service/client/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
from gpu_memory_service.client.memory_manager import (
1616
GMSClientMemoryManager,
17-
StaleWeightsError,
17+
StaleMemoryLayoutError,
1818
)
1919
from gpu_memory_service.client.rpc import GMSRPCClient
2020

2121
__all__ = [
2222
"GMSClientMemoryManager",
23-
"StaleWeightsError",
23+
"StaleMemoryLayoutError",
2424
"GMSRPCClient",
2525
]
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Client-side CUDA VMM utilities.
5+
6+
These functions wrap CUDA driver API calls used by the client memory manager
7+
for importing, mapping, and unmapping GPU memory.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from cuda.bindings import driver as cuda
13+
from gpu_memory_service.common.cuda_vmm_utils import check_cuda_result
14+
from gpu_memory_service.common.types import GrantedLockType
15+
16+
17+
def import_handle_from_fd(fd: int) -> int:
    """Import a CUDA memory handle from a POSIX file descriptor.

    Args:
        fd: POSIX file descriptor received via SCM_RIGHTS.

    Returns:
        CUDA memory handle as an int.
    """
    handle_type = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )
    status, mem_handle = cuda.cuMemImportFromShareableHandle(fd, handle_type)
    check_cuda_result(status, "cuMemImportFromShareableHandle")
    return int(mem_handle)
32+
33+
34+
def reserve_va(size: int, granularity: int) -> int:
    """Reserve a region of device virtual address space.

    Args:
        size: Size in bytes (should be aligned to granularity).
        granularity: VMM allocation granularity.

    Returns:
        The reserved virtual address.
    """
    status, address = cuda.cuMemAddressReserve(size, granularity, 0, 0)
    check_cuda_result(status, "cuMemAddressReserve")
    return int(address)
47+
48+
49+
def free_va(va: int, size: int) -> None:
    """Free a virtual address reservation.

    Args:
        va: Virtual address to free.
        size: Size of the reservation.

    Raises:
        Whatever ``check_cuda_result`` raises for a non-success CUresult.
    """
    # Bug fix: the CUresult returned by cuMemAddressFree was previously
    # discarded, silently swallowing failures. Every other wrapper in this
    # module (unmap, release_handle, ...) checks its result, so check here
    # too for consistency and so that leaks of VA reservations surface.
    (result,) = cuda.cuMemAddressFree(va, size)
    check_cuda_result(result, "cuMemAddressFree")
57+
58+
59+
def map_to_va(va: int, size: int, handle: int) -> None:
    """Map a CUDA memory handle into a reserved virtual address range.

    Args:
        va: Virtual address (must already be reserved).
        size: Size of the mapping in bytes.
        handle: CUDA memory handle.
    """
    status = cuda.cuMemMap(va, size, 0, handle, 0)[0]
    check_cuda_result(status, "cuMemMap")
69+
70+
71+
def set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
    """Set access permissions for a mapped region on one device.

    Args:
        va: Virtual address of the mapped region.
        size: Size of the region in bytes.
        device: CUDA device index.
        access: GrantedLockType.RO grants read-only access; any other
            value grants read-write.
    """
    if access == GrantedLockType.RO:
        prot = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
    else:
        prot = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

    desc = cuda.CUmemAccessDesc()
    desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    desc.location.id = device
    desc.flags = prot

    status = cuda.cuMemSetAccess(va, size, [desc], 1)[0]
    check_cuda_result(status, "cuMemSetAccess")
90+
91+
92+
def unmap(va: int, size: int) -> None:
    """Unmap a previously mapped virtual address region.

    Args:
        va: Virtual address to unmap.
        size: Size of the mapping in bytes.
    """
    status = cuda.cuMemUnmap(va, size)[0]
    check_cuda_result(status, "cuMemUnmap")
101+
102+
103+
def release_handle(handle: int) -> None:
    """Release a CUDA memory handle.

    Args:
        handle: CUDA memory handle to release.
    """
    status = cuda.cuMemRelease(handle)[0]
    check_cuda_result(status, "cuMemRelease")

0 commit comments

Comments
 (0)