diff --git a/.devcontainer/post-create.sh b/.devcontainer/post-create.sh index 4c203a3148f..858864af536 100755 --- a/.devcontainer/post-create.sh +++ b/.devcontainer/post-create.sh @@ -111,7 +111,8 @@ $SANITY_STATUS Now build the project: cargo build --locked --profile dev --features dynamo-llm/block-manager cd lib/bindings/python && maturin develop --uv - DYNAMO_BIN_PATH=$CARGO_TARGET_DIR/debug uv pip install -e . + uv pip install -e lib/gpu_memory_service # GPU memory manager with C++ extension + DYNAMO_BIN_PATH=\$CARGO_TARGET_DIR/debug uv pip install -e . Optional: cd lib/bindings/kvbm && maturin develop --uv # For KVBM support diff --git a/.dockerignore b/.dockerignore index c4e479600e2..2f1eaf01461 100644 --- a/.dockerignore +++ b/.dockerignore @@ -45,6 +45,10 @@ container/Dockerfile* .venv .venv-docs +# GPU Memory Service build artifacts +lib/gpu_memory_service/build/ +lib/gpu_memory_service/*.egg-info/ +lib/gpu_memory_service/**/*.so # Python __pycache__/ diff --git a/.github/filters.yaml b/.github/filters.yaml index 9ae96bbf433..98944c0df22 100644 --- a/.github/filters.yaml +++ b/.github/filters.yaml @@ -78,6 +78,7 @@ core: - 'components/src/dynamo/mocker/**' - 'components/src/dynamo/frontend/**' - 'components/src/dynamo/common/**' + - 'components/src/dynamo/gpu_memory_service/**' - '*.toml' - '*.lock' - '*.py' diff --git a/.gitignore b/.gitignore index 360637f3a9c..5c063ce7fcd 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ tensorrtllm_checkpoints/ tensorrtllm_engines/ api_server_models/ server/ +!lib/gpu_memory_service/server/ # Replay/Snapshot test artifacts *.new lib/llm/tests/data/sample-models/models--meta-llama--Llama-3.1-70B-Instruct/ diff --git a/README.md b/README.md index ac9d60ebef3..027403c5e62 100644 --- a/README.md +++ b/README.md @@ -331,7 +331,16 @@ cd lib/bindings/python maturin develop --uv ``` -## 6. Install the Wheel +## 6. Install GPU Memory Service + +The GPU Memory Service is a Python package with a C++ extension. 
It requires only Python development headers and a C++ compiler (g++). + +```bash +cd $PROJECT_ROOT +uv pip install -e lib/gpu_memory_service +``` + +## 7. Install the Wheel ``` cd $PROJECT_ROOT diff --git a/components/src/dynamo/gpu_memory_service/__init__.py b/components/src/dynamo/gpu_memory_service/__init__.py new file mode 100644 index 00000000000..78200cdb260 --- /dev/null +++ b/components/src/dynamo/gpu_memory_service/__init__.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service component for Dynamo. + +This module provides the Dynamo component wrapper around the gpu_memory_service package. +The core functionality is in the gpu_memory_service package; this module provides: +- CLI entry point (python -m dynamo.gpu_memory_service) +- Re-exports for backwards compatibility +""" + +# Re-export core functionality from gpu_memory_service package +from gpu_memory_service import ( + GMSClientMemoryManager, + StaleMemoryLayoutError, + get_gms_client_memory_manager, + get_or_create_gms_client_memory_manager, +) + +# Re-export extensions (built separately) +try: + from gpu_memory_service.client.torch.extensions import _allocator_ext +except (ImportError, OSError): + _allocator_ext = None + +# Re-export module utilities +from gpu_memory_service.client.torch.module import ( + materialize_module_from_gms, + register_module_tensors, +) + +__all__ = [ + # Core + "GMSClientMemoryManager", + "StaleMemoryLayoutError", + # GMS client memory manager + "get_or_create_gms_client_memory_manager", + "get_gms_client_memory_manager", + # Tensor utilities + "register_module_tensors", + "materialize_module_from_gms", + # Extensions + "_allocator_ext", +] diff --git a/components/src/dynamo/gpu_memory_service/__main__.py b/components/src/dynamo/gpu_memory_service/__main__.py new file mode 100644 index 00000000000..4a439ae666d --- /dev/null +++ 
b/components/src/dynamo/gpu_memory_service/__main__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from dynamo.gpu_memory_service.server import main + +if __name__ == "__main__": + main() diff --git a/components/src/dynamo/gpu_memory_service/args.py b/components/src/dynamo/gpu_memory_service/args.py new file mode 100644 index 00000000000..bf28dd9a379 --- /dev/null +++ b/components/src/dynamo/gpu_memory_service/args.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Argument parsing for GPU Memory Service server component.""" + +import argparse +import logging +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class Config: + """Configuration for GPU Memory Service server.""" + + # GPU Memory Service specific + device: int + socket_path: str + verbose: bool + + +def parse_args() -> Config: + """Parse command line arguments for GPU Memory Service server.""" + parser = argparse.ArgumentParser( + description="GPU Memory Service allocation server for Dynamo." + ) + + # GPU Memory Service specific arguments + parser.add_argument( + "--device", + type=int, + required=True, + help="CUDA device ID to manage memory for.", + ) + parser.add_argument( + "--socket-path", + type=str, + default=None, + help="Path for Unix domain socket. Default: /tmp/gpu_memory_service_{device}.sock. 
" + "Supports {device} placeholder for multi-GPU setups.", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging.", + ) + + args = parser.parse_args() + + # Generate default socket path if not provided + socket_path = args.socket_path + if socket_path is None: + socket_path = f"/tmp/gpu_memory_service_{args.device}.sock" + else: + # Expand {device} placeholder + socket_path = socket_path.format(device=args.device) + + config = Config( + device=args.device, + socket_path=socket_path, + verbose=args.verbose, + ) + + return config diff --git a/components/src/dynamo/gpu_memory_service/server.py b/components/src/dynamo/gpu_memory_service/server.py new file mode 100644 index 00000000000..3eb833733c7 --- /dev/null +++ b/components/src/dynamo/gpu_memory_service/server.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service allocation server component for Dynamo. + +This component wraps the GMSRPCServer from gpu_memory_service to manage +GPU memory allocations with connection-based RW/RO locking. 
+ +Workers connect via the socket path, which should be passed to vLLM/SGLang via: + --load-format gpu_memory_service + --model-loader-extra-config '{"gpu_memory_service_socket_path": "/tmp/gpu_memory_service_{device}.sock"}' + +Usage: + python -m dynamo.gpu_memory_service --device 0 + python -m dynamo.gpu_memory_service --device 0 --socket-path /tmp/gpu_memory_service_{device}.sock +""" + +import asyncio +import logging +import signal + +import uvloop +from gpu_memory_service.server import GMSRPCServer + +from .args import parse_args + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +async def worker() -> None: + """Main async worker function.""" + config = parse_args() + + # Configure logging level + if config.verbose: + logging.getLogger().setLevel(logging.DEBUG) + logging.getLogger("dynamo.gpu_memory_service").setLevel(logging.DEBUG) + + logger.info(f"Starting GPU Memory Service Server for device {config.device}") + logger.info(f"Socket path: {config.socket_path}") + + server = GMSRPCServer(config.socket_path, device=config.device) + + # Set up shutdown handling + shutdown_event = asyncio.Event() + + def signal_handler(): + logger.info("Received shutdown signal") + shutdown_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + await server.start() + + logger.info("GPU Memory Service Server ready, waiting for connections...") + logger.info( + f"To connect vLLM workers, use: --load-format gpu_memory_service " + f'--model-loader-extra-config \'{{"gpu_memory_service_socket_path": "{config.socket_path}"}}\'' + ) + + # Wait for shutdown signal + try: + await shutdown_event.wait() + finally: + logger.info("Shutting down GPU Memory Service Server...") + await server.stop() + logger.info("GPU Memory Service Server shutdown complete") + + +def main() -> None: + """Entry point 
for GPU Memory Service server.""" + uvloop.install() + asyncio.run(worker()) + + +if __name__ == "__main__": + main() diff --git a/container/Dockerfile b/container/Dockerfile index c80307a5c5e..47bb60416af 100644 --- a/container/Dockerfile +++ b/container/Dockerfile @@ -27,6 +27,7 @@ ARG EPP_IMAGE="us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inferen ARG PYTHON_VERSION ARG ENABLE_KVBM +ARG ENABLE_GPU_MEMORY_SERVICE ARG ENABLE_MEDIA_NIXL ARG ENABLE_MEDIA_FFMPEG ARG CARGO_BUILD_JOBS @@ -431,6 +432,13 @@ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ fi && \ /tmp/use-sccache.sh show-stats "Dynamo" +# Build gpu_memory_service wheel (C++ extension only needs Python headers, no CUDA/torch) +ARG ENABLE_GPU_MEMORY_SERVICE +RUN if [ "$ENABLE_GPU_MEMORY_SERVICE" = "true" ]; then \ + source ${VIRTUAL_ENV}/bin/activate && \ + uv build --wheel --out-dir /opt/dynamo/dist /opt/dynamo/lib/gpu_memory_service; \ + fi + ############################################## ########## Runtime image ############## ############################################## @@ -502,10 +510,19 @@ ENV VIRTUAL_ENV=/opt/dynamo/venv \ # Install dynamo wheels (runtime packages only, no test dependencies) ARG ENABLE_KVBM +ARG ENABLE_GPU_MEMORY_SERVICE RUN uv pip install \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ /opt/dynamo/wheelhouse/nixl/nixl*.whl && \ + if [ "$ENABLE_GPU_MEMORY_SERVICE" = "true" ]; then \ + GMS_WHEEL=$(ls /opt/dynamo/wheelhouse/gpu_memory_service*.whl 2>/dev/null | head -1); \ + if [ -z "$GMS_WHEEL" ]; then \ + echo "ERROR: ENABLE_GPU_MEMORY_SERVICE is true but no gpu_memory_service wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + uv pip install "$GMS_WHEEL"; \ + fi && \ if [ "$ENABLE_KVBM" = "true" ]; then \ KVBM_WHEEL=$(ls /opt/dynamo/wheelhouse/kvbm*.whl 2>/dev/null | head -1); \ if [ -z "$KVBM_WHEEL" ]; then \ @@ -593,10 +610,19 @@ RUN 
--mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requi --requirement /tmp/requirements.test.txt ARG ENABLE_KVBM +ARG ENABLE_GPU_MEMORY_SERVICE RUN uv pip install \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ /opt/dynamo/wheelhouse/nixl/nixl*.whl && \ + if [ "$ENABLE_GPU_MEMORY_SERVICE" = "true" ]; then \ + GMS_WHEEL=$(ls /opt/dynamo/wheelhouse/gpu_memory_service*.whl 2>/dev/null | head -1); \ + if [ -z "$GMS_WHEEL" ]; then \ + echo "ERROR: ENABLE_GPU_MEMORY_SERVICE is true but no gpu_memory_service wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + uv pip install "$GMS_WHEEL"; \ + fi && \ if [ "$ENABLE_KVBM" = "true" ]; then \ KVBM_WHEEL=$(ls /opt/dynamo/wheelhouse/kvbm*.whl 2>/dev/null | head -1); \ if [ -z "$KVBM_WHEEL" ]; then \ diff --git a/container/Dockerfile.sglang b/container/Dockerfile.sglang index 79682748501..3e7a0e87331 100644 --- a/container/Dockerfile.sglang +++ b/container/Dockerfile.sglang @@ -36,6 +36,7 @@ ARG BASE_IMAGE_TAG ARG PYTHON_VERSION ARG ENABLE_KVBM +ARG ENABLE_GPU_MEMORY_SERVICE ARG ENABLE_MEDIA_NIXL ARG ENABLE_MEDIA_FFMPEG ARG CARGO_BUILD_JOBS @@ -442,6 +443,13 @@ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ fi && \ /tmp/use-sccache.sh show-stats "Dynamo" +# Build gpu_memory_service wheel (C++ extension only needs Python headers, no CUDA/torch) +ARG ENABLE_GPU_MEMORY_SERVICE +RUN if [ "$ENABLE_GPU_MEMORY_SERVICE" = "true" ]; then \ + source ${VIRTUAL_ENV}/bin/activate && \ + uv build --wheel --out-dir /opt/dynamo/dist /opt/dynamo/lib/gpu_memory_service; \ + fi + ################################## ########## Runtime Image ######### ################################## @@ -500,12 +508,21 @@ COPY --chmod=775 --chown=dynamo:0 --from=wheel_builder /workspace/nixl/build/src ENV SGLANG_VERSION="${RUNTIME_IMAGE_TAG%%-*}" # Install packages as root to ensure they go to system location (/usr/local/lib/python3.12/dist-packages) +ARG ENABLE_GPU_MEMORY_SERVICE 
RUN --mount=type=bind,source=.,target=/mnt/local_src \ pip install --no-cache-dir --break-system-packages \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ /opt/dynamo/wheelhouse/nixl/nixl*.whl \ - sglang==${SGLANG_VERSION} + sglang==${SGLANG_VERSION} && \ + if [ "${ENABLE_GPU_MEMORY_SERVICE}" = "true" ]; then \ + GMS_WHEEL=$(ls /opt/dynamo/wheelhouse/gpu_memory_service*.whl 2>/dev/null | head -1); \ + if [ -z "$GMS_WHEEL" ]; then \ + echo "ERROR: ENABLE_GPU_MEMORY_SERVICE is true but no gpu_memory_service wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + pip install --no-cache-dir --break-system-packages "$GMS_WHEEL"; \ + fi # Install common and test dependencies as root RUN --mount=type=bind,source=.,target=/mnt/local_src \ diff --git a/container/Dockerfile.trtllm b/container/Dockerfile.trtllm index 1aaaaa2ff8d..a0cce92b835 100644 --- a/container/Dockerfile.trtllm +++ b/container/Dockerfile.trtllm @@ -36,6 +36,7 @@ ARG BASE_IMAGE_TAG ARG PYTHON_VERSION ARG ENABLE_KVBM +ARG ENABLE_GPU_MEMORY_SERVICE ARG ENABLE_MEDIA_NIXL ARG ENABLE_MEDIA_FFMPEG ARG CARGO_BUILD_JOBS @@ -454,6 +455,13 @@ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ fi && \ /tmp/use-sccache.sh show-stats "Dynamo" +# Build gpu_memory_service wheel (C++ extension only needs Python headers, no CUDA/torch) +ARG ENABLE_GPU_MEMORY_SERVICE +RUN if [ "$ENABLE_GPU_MEMORY_SERVICE" = "true" ]; then \ + source ${VIRTUAL_ENV}/bin/activate && \ + uv build --wheel --out-dir /opt/dynamo/dist /opt/dynamo/lib/gpu_memory_service; \ + fi + ################################################## ########## Framework Builder Stage ############## ################################################## @@ -770,12 +778,21 @@ COPY --chmod=775 --chown=dynamo:0 benchmarks/ /workspace/benchmarks/ # Install dynamo, NIXL, and dynamo-specific dependencies # Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not ARG ENABLE_KVBM +ARG 
ENABLE_GPU_MEMORY_SERVICE COPY --chmod=775 --chown=dynamo:0 --from=wheel_builder /opt/dynamo/dist/*.whl /opt/dynamo/wheelhouse/ RUN uv pip install \ --no-cache \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ /opt/dynamo/wheelhouse/nixl/nixl*.whl && \ + if [ "${ENABLE_GPU_MEMORY_SERVICE}" = "true" ]; then \ + GMS_WHEEL=$(ls /opt/dynamo/wheelhouse/gpu_memory_service*.whl 2>/dev/null | head -1); \ + if [ -z "$GMS_WHEEL" ]; then \ + echo "ERROR: ENABLE_GPU_MEMORY_SERVICE is true but no gpu_memory_service wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + uv pip install --no-cache "$GMS_WHEEL"; \ + fi && \ if [ "${ENABLE_KVBM}" = "true" ]; then \ KVBM_WHEEL=$(ls /opt/dynamo/wheelhouse/kvbm*.whl 2>/dev/null | head -1); \ if [ -z "$KVBM_WHEEL" ]; then \ diff --git a/container/Dockerfile.vllm b/container/Dockerfile.vllm index f10e91e4356..ee1f32d9ce7 100644 --- a/container/Dockerfile.vllm +++ b/container/Dockerfile.vllm @@ -41,6 +41,7 @@ ARG BASE_IMAGE_TAG ARG PYTHON_VERSION ARG ENABLE_KVBM +ARG ENABLE_GPU_MEMORY_SERVICE ARG ENABLE_MEDIA_NIXL ARG ENABLE_MEDIA_FFMPEG ARG CARGO_BUILD_JOBS @@ -481,6 +482,13 @@ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ fi && \ /tmp/use-sccache.sh show-stats "Dynamo" +# Build gpu_memory_service wheel (C++ extension only needs Python headers, no CUDA/torch) +ARG ENABLE_GPU_MEMORY_SERVICE +RUN if [ "$ENABLE_GPU_MEMORY_SERVICE" = "true" ]; then \ + source ${VIRTUAL_ENV}/bin/activate && \ + uv build --wheel --out-dir /opt/dynamo/dist /opt/dynamo/lib/gpu_memory_service; \ + fi + ######################################################## ########## Framework Development Image ################ ######################################################## @@ -605,6 +613,7 @@ COPY --from=dynamo_base /usr/local/cuda/bin/fatbinary /usr/local/cuda/bin/fatbin COPY --from=dynamo_base /usr/local/cuda/include/ /usr/local/cuda/include/ COPY --from=dynamo_base /usr/local/cuda/nvvm 
/usr/local/cuda/nvvm COPY --from=dynamo_base /usr/local/cuda/lib64/libcudart.so* /usr/local/cuda/lib64/ +COPY --from=dynamo_base /usr/local/cuda/lib64/stubs/ /usr/local/cuda/lib64/stubs/ RUN CUDA_VERSION_MAJOR="${CUDA_VERSION%%.*}" &&\ ln -s /usr/local/cuda/lib64/libcublas.so.${CUDA_VERSION_MAJOR} /usr/local/cuda/lib64/libcublas.so &&\ ln -s /usr/local/cuda/lib64/libcublasLt.so.${CUDA_VERSION_MAJOR} /usr/local/cuda/lib64/libcublasLt.so @@ -744,11 +753,20 @@ COPY --chmod=775 --chown=dynamo:0 benchmarks/ /workspace/benchmarks/ # Install dynamo, NIXL, and dynamo-specific dependencies # Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not ARG ENABLE_KVBM +ARG ENABLE_GPU_MEMORY_SERVICE COPY --chmod=775 --chown=dynamo:0 --from=wheel_builder /opt/dynamo/dist/*.whl /opt/dynamo/wheelhouse/ RUN uv pip install \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ /opt/dynamo/wheelhouse/nixl/nixl*.whl && \ + if [ "${ENABLE_GPU_MEMORY_SERVICE}" = "true" ]; then \ + GMS_WHEEL=$(ls /opt/dynamo/wheelhouse/gpu_memory_service*.whl 2>/dev/null | head -1); \ + if [ -z "$GMS_WHEEL" ]; then \ + echo "ERROR: ENABLE_GPU_MEMORY_SERVICE is true but no gpu_memory_service wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + uv pip install "$GMS_WHEEL"; \ + fi && \ if [ "${ENABLE_KVBM}" = "true" ]; then \ KVBM_WHEEL=$(ls /opt/dynamo/wheelhouse/kvbm*.whl 2>/dev/null | head -1); \ if [ -z "$KVBM_WHEEL" ]; then \ @@ -823,6 +841,7 @@ RUN cd /usr/local/lib && \ ldconfig USER dynamo + ARG DYNAMO_COMMIT_SHA ENV DYNAMO_COMMIT_SHA=$DYNAMO_COMMIT_SHA diff --git a/container/build.sh b/container/build.sh index fca7ac7f9fe..953606a8a6d 100755 --- a/container/build.sh +++ b/container/build.sh @@ -156,6 +156,10 @@ PUSH="" # or can be explicitly enabled via --enable-kvbm flag ENABLE_KVBM=false +# GPU Memory Service - default disabled, enabled automatically for VLLM/SGLANG +# or can be explicitly enabled via 
--enable-gpu-memory-service flag +ENABLE_GPU_MEMORY_SERVICE=false + # sccache configuration for S3 USE_SCCACHE="" SCCACHE_BUCKET="" @@ -343,6 +347,9 @@ get_options() { --enable-kvbm) ENABLE_KVBM=true ;; + --enable-gpu-memory-service) + ENABLE_GPU_MEMORY_SERVICE=true + ;; --enable-media-nixl) ENABLE_MEDIA_NIXL=true ;; @@ -539,6 +546,7 @@ show_help() { echo " [--release-build perform a release build]" echo " [--make-efa Adds AWS EFA layer on top of the built image (works with any target)]" echo " [--enable-kvbm Enables KVBM support in Python 3.12]" + echo " [--enable-gpu-memory-service Enables GPU Memory Service support]" echo " [--enable-media-nixl Enable media processing with NIXL support (default: true for frameworks, false for none)]" echo " [--enable-media-ffmpeg Enable media processing with FFMPEG support (default: true for frameworks, false for none)]" echo " [--use-sccache enable sccache for Rust/C/C++ compilation caching]" @@ -831,6 +839,20 @@ if [[ ${ENABLE_KVBM} == "true" ]]; then BUILD_ARGS+=" --build-arg ENABLE_KVBM=${ENABLE_KVBM} " fi +# ENABLE_GPU_MEMORY_SERVICE: Used in Dockerfiles for gpu_memory_service wheel. +# Declared but not currently used in Dockerfile.trtllm. +# Force GPU Memory Service to be enabled for VLLM and SGLANG frameworks +if [[ $FRAMEWORK == "VLLM" ]] || [[ $FRAMEWORK == "SGLANG" ]]; then + echo "Forcing enable_gpu_memory_service to true in ${FRAMEWORK} image build" + ENABLE_GPU_MEMORY_SERVICE=true +fi +# For other frameworks, ENABLE_GPU_MEMORY_SERVICE defaults to false unless --enable-gpu-memory-service flag was provided + +if [[ ${ENABLE_GPU_MEMORY_SERVICE} == "true" ]]; then + echo "Enabling GPU Memory Service in the dynamo image" + BUILD_ARGS+=" --build-arg ENABLE_GPU_MEMORY_SERVICE=${ENABLE_GPU_MEMORY_SERVICE} " +fi + # ENABLE_MEDIA_NIXL: Enable media processing with NIXL support # Used in base Dockerfile for maturin build feature flag. 
# Can be explicitly overridden with --enable-media-nixl flag diff --git a/lib/gpu_memory_service/__init__.py b/lib/gpu_memory_service/__init__.py new file mode 100644 index 00000000000..3aa1b1dbbaa --- /dev/null +++ b/lib/gpu_memory_service/__init__.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service - out-of-process GPU memory manager. + +The GPU Memory Service decouples ownership of GPU memory from the processes +that use it, enabling zero-copy sharing and data survival across process crashes. + +Package structure: +- common/: Shared types and protocol (used by both server and client) +- server/: Allocation server daemon (no CUDA context required) +- client/: Client library for memory management + - client/torch/: PyTorch integration (allocator, tensor, module, extensions) + +Primary client API: + from gpu_memory_service import ( + GMSClientMemoryManager, + get_or_create_gms_client_memory_manager, + get_gms_client_memory_manager, + ) + +Server API: + from gpu_memory_service.server import GMSRPCServer +""" + +# Primary client exports +from gpu_memory_service.client.memory_manager import ( + GMSClientMemoryManager, + StaleMemoryLayoutError, +) + +# PyTorch integration (GMS client memory manager) +from gpu_memory_service.client.torch.allocator import ( + get_gms_client_memory_manager, + get_or_create_gms_client_memory_manager, +) + +__all__ = [ + # Client + "GMSClientMemoryManager", + "StaleMemoryLayoutError", + # GMS client memory manager + "get_or_create_gms_client_memory_manager", + "get_gms_client_memory_manager", +] diff --git a/lib/gpu_memory_service/client/__init__.py b/lib/gpu_memory_service/client/__init__.py new file mode 100644 index 00000000000..80eea63d86d --- /dev/null +++ b/lib/gpu_memory_service/client/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service client library. + +This module provides the client-side components for interacting with the +GPU Memory Service: + +- GMSClientMemoryManager: Manages local VA mappings of remote GPU memory +- GMSRPCClient: Low-level RPC client (pure Python, no PyTorch dependency) + +For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch. +""" + +from gpu_memory_service.client.memory_manager import ( + GMSClientMemoryManager, + StaleMemoryLayoutError, +) +from gpu_memory_service.client.rpc import GMSRPCClient + +__all__ = [ + "GMSClientMemoryManager", + "StaleMemoryLayoutError", + "GMSRPCClient", +] diff --git a/lib/gpu_memory_service/client/cuda_vmm_utils.py b/lib/gpu_memory_service/client/cuda_vmm_utils.py new file mode 100644 index 00000000000..9483530142d --- /dev/null +++ b/lib/gpu_memory_service/client/cuda_vmm_utils.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Client-side CUDA VMM utilities. + +These functions wrap CUDA driver API calls used by the client memory manager +for importing, mapping, and unmapping GPU memory. +""" + +from __future__ import annotations + +from cuda.bindings import driver as cuda +from gpu_memory_service.common.cuda_vmm_utils import check_cuda_result +from gpu_memory_service.common.types import GrantedLockType + + +def import_handle_from_fd(fd: int) -> int: + """Import a CUDA memory handle from a file descriptor. + + Args: + fd: POSIX file descriptor received via SCM_RIGHTS. + + Returns: + CUDA memory handle. 
+ """ + result, handle = cuda.cuMemImportFromShareableHandle( + fd, + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, + ) + check_cuda_result(result, "cuMemImportFromShareableHandle") + return int(handle) + + +def reserve_va(size: int, granularity: int) -> int: + """Reserve virtual address space. + + Args: + size: Size in bytes (should be aligned to granularity). + granularity: VMM allocation granularity. + + Returns: + Reserved virtual address. + """ + result, va = cuda.cuMemAddressReserve(size, granularity, 0, 0) + check_cuda_result(result, "cuMemAddressReserve") + return int(va) + + +def free_va(va: int, size: int) -> None: + """Free a virtual address reservation. + + Args: + va: Virtual address to free. + size: Size of the reservation. + """ + (result,) = cuda.cuMemAddressFree(va, size) + check_cuda_result(result, "cuMemAddressFree") + + +def map_to_va(va: int, size: int, handle: int) -> None: + """Map a CUDA handle to a virtual address. + + Args: + va: Virtual address (must be reserved). + size: Size of the mapping. + handle: CUDA memory handle. + """ + (result,) = cuda.cuMemMap(va, size, 0, handle, 0) + check_cuda_result(result, "cuMemMap") + + +def set_access(va: int, size: int, device: int, access: GrantedLockType) -> None: + """Set access permissions for a mapped region. + + Args: + va: Virtual address. + size: Size of the region. + device: CUDA device index. + access: Access mode - RO for read-only, RW for read-write. + """ + acc = cuda.CUmemAccessDesc() + acc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + acc.location.id = device + acc.flags = ( + cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ + if access == GrantedLockType.RO + else cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + ) + (result,) = cuda.cuMemSetAccess(va, size, [acc], 1) + check_cuda_result(result, "cuMemSetAccess") + + +def unmap(va: int, size: int) -> None: + """Unmap a virtual address region. 
+ + Args: + va: Virtual address to unmap. + size: Size of the mapping. + """ + (result,) = cuda.cuMemUnmap(va, size) + check_cuda_result(result, "cuMemUnmap") + + +def release_handle(handle: int) -> None: + """Release a CUDA memory handle. + + Args: + handle: CUDA memory handle to release. + """ + (result,) = cuda.cuMemRelease(handle) + check_cuda_result(result, "cuMemRelease") diff --git a/lib/gpu_memory_service/client/memory_manager.py b/lib/gpu_memory_service/client/memory_manager.py new file mode 100644 index 00000000000..149b3d18034 --- /dev/null +++ b/lib/gpu_memory_service/client/memory_manager.py @@ -0,0 +1,650 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service client-side memory manager. + +This is the unified memory manager for the GPU Memory Service architecture. + +Key properties: +- Uses GMSRPCClient over a Unix-domain socket. +- The socket connection itself is the RW/RO lock. +- In write mode, the manager can allocate + map RW and then publish via commit(). +- In read mode, the manager can import + map RO and hold the RO lock during inference. +- sleep()/wake() releases and reacquires the RO lock (and remaps allocations). 
+ +This module uses cuda-python bindings for CUDA driver API calls: +- import FDs (cuMemImportFromShareableHandle) +- reserve VA (cuMemAddressReserve) +- map/unmap (cuMemMap/cuMemUnmap) +- enforce access (cuMemSetAccess) +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +from cuda.bindings import driver as cuda +from gpu_memory_service.client.cuda_vmm_utils import ( + free_va, + import_handle_from_fd, + map_to_va, + release_handle, + reserve_va, + set_access, + unmap, +) +from gpu_memory_service.client.rpc import GMSRPCClient +from gpu_memory_service.common.cuda_vmm_utils import ( + align_to_granularity, + get_allocation_granularity, +) +from gpu_memory_service.common.types import GrantedLockType, RequestedLockType + +logger = logging.getLogger(__name__) + + +class StaleMemoryLayoutError(Exception): + """Raised when memory layout was modified while sleeping. + + This error indicates that a writer acquired the RW lock and changed the + allocation structure (different sizes, different tensor layouts) while this + reader was sleeping. The caller should re-import the model from scratch. + + IMPORTANT: This is a LAYOUT check, NOT a CONTENT check. + - Detected: Allocation sizes changed, tensors added/removed, metadata structure changed + - NOT detected: Weight values modified in-place + + This design is intentional: sleep/wake enables use cases like RL training + where another process can write to the same memory locations (e.g., updating + weights) while preserving the structure. As long as the layout (allocation + and metadata table hashes) remains identical, wake() succeeds. 
+ """ + + pass + + +@dataclass(frozen=True) +class LocalMapping: + """Immutable record of a local VA mapping.""" + + allocation_id: str + va: int + size: int + aligned_size: int + handle: int # 0 if unmapped but VA reserved + tag: str + access: GrantedLockType + + def with_handle(self, handle: int) -> "LocalMapping": + return LocalMapping( + self.allocation_id, + self.va, + self.size, + self.aligned_size, + handle, + self.tag, + self.access, + ) + + def with_access(self, access: GrantedLockType) -> "LocalMapping": + return LocalMapping( + self.allocation_id, + self.va, + self.size, + self.aligned_size, + self.handle, + self.tag, + access, + ) + + +class GMSClientMemoryManager: + """Unified memory manager that can act as writer or reader. + + Modes: + - mode=RequestedLockType.RW: acquire RW lock, allocate/map RW, mutate metadata, commit/publish. + - mode=RequestedLockType.RO: acquire RO lock (READY only), import/map RO, sleep/wake. + - mode=RequestedLockType.RW_OR_RO: try RW if available, else wait for RO. + """ + + def __init__( + self, + socket_path: str, + *, + mode: RequestedLockType, + device: int = 0, + timeout_ms: Optional[int] = None, + ) -> None: + self.socket_path = socket_path + self.device = device + self._timeout_ms = timeout_ms + + self._client: Optional[GMSRPCClient] = None + self._mappings: Dict[int, LocalMapping] = {} # va -> mapping + self._allocation_id_to_va: Dict[str, int] = {} + + self._sleeping = False + self._closed = False + self._preserved_allocation_ids: List[str] = [] + self._published = False + self._mode: Optional[GrantedLockType] = None # Updated by _connect + + # VA-stable sleep/wake state + self._va_preserved = False + self._last_memory_layout_hash: str = ( + "" # Hash from server, saved on connect/commit + ) + + # Ensure torch is on the right device for subsequent CUDA operations. 
+ if torch.cuda.is_available(): + torch.cuda.set_device(self.device) + + # Cache granularity for VA alignment + self.granularity = get_allocation_granularity(device) + + self._connect(lock_type=mode, timeout_ms=timeout_ms) + + def _connect( + self, + *, + lock_type: RequestedLockType, + timeout_ms: Optional[int], + update_memory_layout_hash: bool = True, + ) -> None: + self._client = GMSRPCClient( + self.socket_path, lock_type=lock_type, timeout_ms=timeout_ms + ) + self._sleeping = False + # Update mode based on granted lock type (may differ from requested for rw_or_ro) + self._mode = self._client.lock_type + # Save state hash for stale detection on wake (skip during wake itself) + if update_memory_layout_hash and self._client.committed: + self._last_memory_layout_hash = self._client.get_memory_layout_hash() + + @property + def mode(self) -> Optional[GrantedLockType]: + """Current mode of the memory manager.""" + return self._mode + + @property + def lock_type(self) -> Optional[GrantedLockType]: + """Get the lock type actually granted by the server.""" + if self._client is None: + return None + return self._client.lock_type + + @property + def is_connected(self) -> bool: + return self._client is not None and self._client.is_connected + + @property + def is_sleeping(self) -> bool: + return self._sleeping + + @property + def mappings(self) -> Dict[int, LocalMapping]: + """Read-only view of VA -> LocalMapping dictionary.""" + return self._mappings + + @property + def total_bytes(self) -> int: + """Total bytes allocated across all mappings.""" + return sum(m.aligned_size for m in self._mappings.values()) + + # ==================== Metadata convenience ==================== + + def metadata_put( + self, key: str, allocation_id: str, offset_bytes: int, value: bytes + ) -> bool: + return self._client_rpc.metadata_put(key, allocation_id, offset_bytes, value) + + def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]: + return self._client_rpc.metadata_get(key) 
+ + def metadata_list(self, prefix: str = "") -> List[str]: + return self._client_rpc.metadata_list(prefix) + + def metadata_delete(self, key: str) -> bool: + return self._client_rpc.metadata_delete(key) + + # ==================== Allocation operations ==================== + + def list_allocations(self, tag: Optional[str] = None) -> List[Dict]: + """List all allocations on the server.""" + return self._client_rpc.list_allocations(tag) + + def allocate_and_map(self, size: int, tag: str = "default") -> int: + """Allocate on server, reserve VA, and map locally. + + Args: + size: Requested allocation size in bytes. + tag: Allocation tag for server tracking. + + Returns: + Virtual address of the mapped allocation. + """ + self._require_rw() + client = self._client_rpc + aligned_size = align_to_granularity(size, self.granularity) + + va = reserve_va(aligned_size, self.granularity) + try: + allocation_id, server_aligned = client.allocate(aligned_size, tag) + if int(server_aligned) != aligned_size: + raise RuntimeError( + f"Alignment mismatch: {aligned_size} vs {server_aligned}" + ) + + fd = client.export(allocation_id) + handle = import_handle_from_fd(fd) + map_to_va(va, aligned_size, handle) + set_access(va, aligned_size, self.device, GrantedLockType.RW) + + self._track_mapping( + LocalMapping( + allocation_id=allocation_id, + va=va, + size=size, + aligned_size=aligned_size, + handle=handle, + tag=tag, + access=GrantedLockType.RW, + ) + ) + return va + except Exception: + free_va(va, aligned_size) + raise + + def free_mapping(self, va: int) -> None: + """Unmap and free a local mapping.""" + mapping = self._mappings.pop(va, None) + if mapping is None: + return + + self._allocation_id_to_va.pop(mapping.allocation_id, None) + + try: + if mapping.handle != 0: + unmap(va, mapping.aligned_size) + release_handle(mapping.handle) + free_va(va, mapping.aligned_size) + except Exception as e: + logger.warning(f"Error freeing VA 0x{va:x}: {e}") + + if self.lock_type == 
GrantedLockType.RW and not self._published: + try: + self._client_rpc.free(mapping.allocation_id) + except Exception: + pass + + def import_allocation(self, allocation_id: str) -> int: + """Import an existing allocation and map locally. + + In RO mode, maps read-only. In RW mode, maps read-write. + """ + if allocation_id in self._allocation_id_to_va: + return self._allocation_id_to_va[allocation_id] + + client = self._client_rpc + # lock_type is guaranteed non-None when connected (after _client_rpc succeeds) + assert self.lock_type is not None + current_access = self.lock_type + alloc_info = client.get_allocation(allocation_id) + aligned_size = int(alloc_info.aligned_size) + size = int(alloc_info.size) + tag = str(getattr(alloc_info, "tag", "default")) + + va = reserve_va(aligned_size, self.granularity) + try: + fd = client.export(allocation_id) + handle = import_handle_from_fd(fd) + map_to_va(va, aligned_size, handle) + set_access(va, aligned_size, self.device, current_access) + + self._track_mapping( + LocalMapping( + allocation_id=allocation_id, + va=va, + size=size, + aligned_size=aligned_size, + handle=handle, + tag=tag, + access=current_access, + ) + ) + return va + except Exception: + free_va(va, aligned_size) + raise + + def clear_all(self) -> int: + """Clear all allocations on the server (RW only). Local mappings are unmapped first.""" + self._require_rw() + self._unmap_all() + return self._client_rpc.clear_all() + + # ==================== Publish / mode switching ==================== + + def commit(self) -> bool: + """Publish weights (RW only). + + Client responsibilities: + - cudaDeviceSynchronize() before publishing + - flip local mappings to RO before publishing + + Server responsibilities: + - transition to COMMITTED + - close the RW socket (publish + release) + """ + self._require_rw() + + if torch.cuda.is_available(): + torch.cuda.synchronize(self.device) + + # After publishing, prevent further writes locally. 
+ for va, m in list(self._mappings.items()): + if m.access != GrantedLockType.RO: + set_access(m.va, m.aligned_size, self.device, GrantedLockType.RO) + self._mappings[va] = m.with_access(GrantedLockType.RO) + + ok = self._client_rpc.commit() + self._published = bool(ok) + # _client.commit() closes the socket on success; reflect that here. + if ok: + self._client = None + return bool(ok) + + def switch_to_read(self, timeout_ms: Optional[int] = None) -> None: + """Acquire an RO lock after publishing. + + This is intended for the common flow where a writer loads weights and then + becomes a reader for inference. + """ + if self._closed: + raise RuntimeError("Memory manager is closed") + if self._sleeping: + raise RuntimeError( + "Cannot switch_to_read() while sleeping; call wake() first" + ) + if self._client is not None: + if self.lock_type == GrantedLockType.RO: + return + raise RuntimeError( + "switch_to_read() requires the RW connection to be released (call commit() first)" + ) + + eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms + self._connect(lock_type=RequestedLockType.RO, timeout_ms=eff_timeout) + + # ==================== Sleep / wake (read mode) ==================== + + def sleep(self) -> None: + """Release RO lock and unmap local allocations (VA-stable). + + VAs are preserved during sleep so tensor pointers remain stable. + On wake, allocations are remapped to the same VAs. 
+ """ + if self._closed: + raise RuntimeError("Memory manager is closed") + if self._sleeping: + return + if self.lock_type != GrantedLockType.RO: + raise RuntimeError("sleep() requires RO mode") + + if torch.cuda.is_available(): + torch.cuda.synchronize(self.device) + + # Preserve allocation IDs for remapping on wake + self._preserved_allocation_ids = list(self._allocation_id_to_va.keys()) + + # Unmap physical memory but keep VA reservations + self._unmap_preserving_va() + self._va_preserved = True + + self._client_rpc.close() + self._client = None + self._sleeping = True + + def wake(self, timeout_ms: Optional[int] = None) -> bool: + """Reacquire RO lock and remap preserved allocations (VA-stable). + + Allocations are remapped to the same VAs they had before sleep, + ensuring tensor pointers remain valid. + + Args: + timeout_ms: Timeout for RO lock acquisition. + + Returns: + True on success. + + Raises: + TimeoutError: If timeout_ms expires waiting for RO lock. + StaleMemoryLayoutError: If weights were structurally changed while sleeping. + """ + if self._closed: + raise RuntimeError("Memory manager is closed") + if not self._sleeping: + return True + + if torch.cuda.is_available(): + torch.cuda.set_device(self.device) + + eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms + self._connect( + lock_type=RequestedLockType.RO, + timeout_ms=eff_timeout, + update_memory_layout_hash=False, + ) + + # Check if memory layout changed while sleeping + current_hash = self._client_rpc.get_memory_layout_hash() + if ( + self._last_memory_layout_hash + and current_hash != self._last_memory_layout_hash + ): + raise StaleMemoryLayoutError( + f"State changed while sleeping: hash {self._last_memory_layout_hash[:16]}... -> {current_hash[:16]}..." 
+ ) + + # Remap to preserved VAs + remapped_count = 0 + failed_count = 0 + total_bytes = 0 + for alloc_id in self._preserved_allocation_ids: + try: + va = self._remap_preserved_va(alloc_id) + mapping = self._mappings.get(va) + if mapping: + total_bytes += mapping.aligned_size + remapped_count += 1 + except StaleMemoryLayoutError: + raise # Let StaleMemoryLayoutError propagate + except Exception as e: + logger.warning(f"Failed to remap {alloc_id}: {e}") + failed_count += 1 + + if failed_count > 0: + raise RuntimeError( + f"Wake failed: {failed_count} of {len(self._preserved_allocation_ids)} " + f"allocations could not be remapped" + ) + + logger.info( + f"[GPU Memory Service] Wake complete on device {self.device}: " + f"remapped {remapped_count} allocations ({total_bytes / (1 << 30):.2f} GiB)" + ) + + self._sleeping = False + self._va_preserved = False + return True + + # ==================== Cleanup ==================== + + def close(self) -> None: + if self._closed: + return + + # Ensure kernels are done before tearing down mappings. + if torch.cuda.is_available(): + torch.cuda.synchronize(self.device) + + # Release all mappings including preserved VA reservations + self._unmap_all() + + if self._client is not None: + self._client.close() + self._client = None + self._closed = True + self._sleeping = False + self._va_preserved = False + self._preserved_allocation_ids.clear() + + def __enter__(self) -> "GMSClientMemoryManager": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + # ==================== Internals ==================== + + @property + def _client_rpc(self) -> GMSRPCClient: + """Get connected client or raise. 
Use instead of _require_connected() + assert.""" + if self._client is None: + if self._sleeping: + raise RuntimeError("Memory manager is sleeping") + raise RuntimeError("Memory manager is not connected") + return self._client + + def _require_rw(self) -> None: + """Raise if not in RW mode.""" + if self.lock_type != GrantedLockType.RW: + raise RuntimeError("Operation requires RW mode") + + def _track_mapping(self, m: LocalMapping) -> None: + self._mappings[m.va] = m + self._allocation_id_to_va[m.allocation_id] = m.va + + def _unmap_preserving_va(self) -> None: + """Unmap physical memory but PRESERVE VA reservations for sleep/wake. + + This keeps the VA reservation intact so tensors maintain stable pointers. + On wake, we can remap to the same VAs. + """ + unmapped_count = 0 + total_bytes = 0 + for va, mapping in list(self._mappings.items()): + if mapping.handle == 0: + continue # Already unmapped + try: + unmap(va, mapping.aligned_size) + release_handle(mapping.handle) + self._mappings[va] = mapping.with_handle( + 0 + ) # Mark unmapped, VA reserved + unmapped_count += 1 + total_bytes += mapping.aligned_size + except Exception as e: + logger.warning( + f"Error unmapping VA 0x{va:x} (preserving reservation): {e}" + ) + logger.info( + f"[GPU Memory Service] Unmapped {unmapped_count} allocations ({total_bytes / (1 << 30):.2f} GiB), " + f"preserving {len(self._mappings)} VA reservations" + ) + + def _remap_preserved_va(self, allocation_id: str) -> int: + """Remap an allocation to its preserved VA. + + Requires the VA to already be reserved (from before sleep). + Validates allocation still exists and size matches. + + Returns the VA. + Raises StaleMemoryLayoutError if allocation is missing or size changed. 
+ """ + if torch.cuda.is_available(): + torch.cuda.set_device(self.device) + + va = self._allocation_id_to_va.get(allocation_id) + if va is None: + raise RuntimeError(f"No preserved VA for allocation {allocation_id}") + + mapping = self._mappings.get(va) + if mapping is None: + raise RuntimeError(f"No mapping info for VA 0x{va:x}") + + if mapping.handle != 0: + return va # Already mapped + + client = self._client_rpc + # lock_type is guaranteed non-None when connected (after _client_rpc succeeds) + assert self.lock_type is not None + current_access = self.lock_type + + # Validate allocation still exists and size matches + try: + alloc_info = client.get_allocation(allocation_id) + except Exception as e: + raise StaleMemoryLayoutError( + f"Allocation {allocation_id} no longer exists on server: {e}" + ) from e + + server_aligned_size = int(alloc_info.aligned_size) + if server_aligned_size != mapping.aligned_size: + raise StaleMemoryLayoutError( + f"Allocation {allocation_id} size changed: expected {mapping.aligned_size}, got {server_aligned_size}" + ) + + # Re-import the handle and map to the SAME VA (which is still reserved) + fd = client.export(allocation_id) + handle = import_handle_from_fd(fd) + map_to_va(va, mapping.aligned_size, handle) + + # Set access permissions based on current lock type + set_access(va, mapping.aligned_size, self.device, current_access) + + # Synchronize to ensure mapping is complete before any access + cuda.cuCtxSynchronize() + + # Validate the pointer is accessible (this is what Triton checks) + result, _dev_ptr = cuda.cuPointerGetAttribute( + cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER, va + ) + if result != cuda.CUresult.CUDA_SUCCESS: + err_result, err_str = cuda.cuGetErrorString(result) + err_msg = "" + if err_result == cuda.CUresult.CUDA_SUCCESS and err_str: + err_msg = ( + err_str.decode() if isinstance(err_str, bytes) else str(err_str) + ) + logger.warning( + f"[GPU Memory Service] cuPointerGetAttribute failed for 
VA 0x{va:x} after remap: " + f"error {result} ({err_msg})" + ) + else: + logger.debug( + f"[GPU Memory Service] Remapped VA 0x{va:x} validated OK (device={self.device})" + ) + + # Update mapping with new handle and access + updated = mapping.with_handle(handle) + self._mappings[va] = updated.with_access(current_access) + + return va + + def _unmap_all(self) -> None: + """Unmap and release all local mappings including VA reservations.""" + for va, mapping in list(self._mappings.items()): + try: + if mapping.handle != 0: + unmap(va, mapping.aligned_size) + release_handle(mapping.handle) + free_va(va, mapping.aligned_size) + except Exception as e: + logger.warning(f"Error unmapping VA 0x{va:x}: {e}") + self._mappings.clear() + self._allocation_id_to_va.clear() diff --git a/lib/gpu_memory_service/client/rpc.py b/lib/gpu_memory_service/client/rpc.py new file mode 100644 index 00000000000..8c359188e26 --- /dev/null +++ b/lib/gpu_memory_service/client/rpc.py @@ -0,0 +1,338 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service RPC Client. + +Low-level RPC client stub. The client provides a simple interface for acquiring +locks and performing allocation operations. The socket connection IS the lock. + +This module has NO PyTorch dependency. + +Usage: + # Writer (acquires RW lock in constructor) + with GMSRPCClient(socket_path, lock_type=RequestedLockType.RW) as client: + alloc_id, aligned_size = client.allocate(size=1024*1024) + fd = client.export(alloc_id) + # ... write weights using fd ... + client.commit() + # Lock released on exit + + # Reader (acquires RO lock in constructor) + client = GMSRPCClient(socket_path, lock_type=RequestedLockType.RO) + if client.committed: # Check if weights are valid + allocations = client.list_allocations() + for alloc in allocations: + fd = client.export(alloc["allocation_id"]) + # ... import and map fd ... 
+ # Keep connection open during inference! + # client.close() only when done with inference +""" + +import logging +import socket +from typing import Dict, List, Optional, Tuple, Type, TypeVar + +from gpu_memory_service.common.protocol.messages import ( + AllocateRequest, + AllocateResponse, + ClearAllRequest, + ClearAllResponse, + CommitRequest, + CommitResponse, + ErrorResponse, + ExportRequest, + FreeRequest, + FreeResponse, + GetAllocationRequest, + GetAllocationResponse, + GetAllocationStateRequest, + GetAllocationStateResponse, + GetLockStateRequest, + GetLockStateResponse, + GetStateHashRequest, + GetStateHashResponse, + HandshakeRequest, + HandshakeResponse, + ListAllocationsRequest, + ListAllocationsResponse, + MetadataDeleteRequest, + MetadataDeleteResponse, + MetadataGetRequest, + MetadataGetResponse, + MetadataListRequest, + MetadataListResponse, + MetadataPutRequest, + MetadataPutResponse, +) +from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync +from gpu_memory_service.common.types import ( + RW_REQUIRED, + GrantedLockType, + RequestedLockType, +) + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + + +class GMSRPCClient: + """GPU Memory Service RPC Client. + + CRITICAL: Socket connection IS the lock. 
+ - Constructor blocks until lock is acquired + - close() releases the lock + - committed property tells readers if weights are valid + + For writers (lock_type=RequestedLockType.RW): + - Use context manager (with statement) for automatic lock release + - Call commit() after weights are written + - Call clear_all() before loading new model + + For readers (lock_type=RequestedLockType.RO): + - Check committed property after construction + - Keep connection open during inference lifetime + - Only call close() when shutting down or allowing weight updates + """ + + def __init__( + self, + socket_path: str, + lock_type: RequestedLockType = RequestedLockType.RO, + timeout_ms: Optional[int] = None, + ): + """Connect to Allocation Server and acquire lock. + + Args: + socket_path: Path to server's Unix domain socket + lock_type: Requested lock type (RW, RO, or RW_OR_RO) + timeout_ms: Timeout in milliseconds for lock acquisition. + None means wait indefinitely. + + Raises: + ConnectionError: If connection fails + TimeoutError: If timeout_ms expires waiting for lock + """ + self.socket_path = socket_path + self._requested_lock_type = lock_type + self._socket: Optional[socket.socket] = None + self._recv_buffer = bytearray() + self._committed = False + self._granted_lock_type: Optional[GrantedLockType] = None + + # Connect and acquire lock + self._connect(timeout_ms=timeout_ms) + + def _connect(self, timeout_ms: Optional[int]) -> None: + """Connect to server and perform handshake (lock acquisition).""" + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + self._socket.connect(self.socket_path) + except FileNotFoundError: + raise ConnectionError(f"Server not running at {self.socket_path}") from None + except Exception as e: + raise ConnectionError(f"Failed to connect: {e}") from e + + # Send handshake (this IS lock acquisition) + request = HandshakeRequest( + lock_type=self._requested_lock_type, timeout_ms=timeout_ms + ) + send_message_sync(self._socket, 
request) + + # Receive response (may block waiting for lock) + response, _, self._recv_buffer = recv_message_sync( + self._socket, self._recv_buffer + ) + + if isinstance(response, ErrorResponse): + self._socket.close() + self._socket = None + raise ConnectionError(f"Handshake error: {response.error}") + + if not isinstance(response, HandshakeResponse): + self._socket.close() + self._socket = None + raise ConnectionError(f"Unexpected response: {type(response)}") + + if not response.success: + self._socket.close() + self._socket = None + raise TimeoutError("Timeout waiting for lock") + + self._committed = response.committed + # Store granted lock type (may differ from requested for rw_or_ro mode) + if response.granted_lock_type is not None: + self._granted_lock_type = response.granted_lock_type + elif self._requested_lock_type == RequestedLockType.RW: + self._granted_lock_type = GrantedLockType.RW + else: + self._granted_lock_type = GrantedLockType.RO + logger.info( + f"Connected with {self._requested_lock_type.value} lock (granted={self._granted_lock_type.value}), " + f"committed={self._committed}" + ) + + @property + def committed(self) -> bool: + """Check if weights are committed (valid).""" + return self._committed + + @property + def lock_type(self) -> Optional[GrantedLockType]: + """Get the lock type actually granted by the server. + + For rw_or_ro mode, this tells you whether RW or RO was granted. + """ + return self._granted_lock_type + + @property + def is_connected(self) -> bool: + """Check if client is connected.""" + return self._socket is not None + + def _send_recv(self, request) -> Tuple[object, int]: + """Send request and receive response. 
Returns (response, fd).""" + if not self._socket: + raise RuntimeError("Client not connected") + + send_message_sync(self._socket, request) + response, fd, self._recv_buffer = recv_message_sync( + self._socket, self._recv_buffer + ) + + if isinstance(response, ErrorResponse): + raise RuntimeError(f"Server error: {response.error}") + + return response, fd + + def _call(self, request, response_type: Type[T]) -> T: + """Send request, validate response type, return typed response.""" + if type(request) in RW_REQUIRED and self.lock_type != GrantedLockType.RW: + raise RuntimeError("Operation requires RW connection") + response, _ = self._send_recv(request) + if not isinstance(response, response_type): + raise RuntimeError(f"Unexpected response: {type(response)}") + return response + + def get_lock_state(self) -> GetLockStateResponse: + return self._call(GetLockStateRequest(), GetLockStateResponse) + + def get_allocation_state(self) -> GetAllocationStateResponse: + return self._call(GetAllocationStateRequest(), GetAllocationStateResponse) + + def is_ready(self) -> bool: + return self.committed + + def commit(self) -> bool: + """Commit weights and release RW lock. 
Returns True on success.""" + if CommitRequest in RW_REQUIRED and self.lock_type != GrantedLockType.RW: + raise RuntimeError("Operation requires RW connection") + + try: + response, _ = self._send_recv(CommitRequest()) + ok = isinstance(response, CommitResponse) and response.success + except (ConnectionResetError, BrokenPipeError, OSError) as e: + # Server closes RW socket as part of commit + logger.debug( + f"Commit saw socket error ({type(e).__name__}); verifying via RO connect" + ) + self.close() + try: + ro = GMSRPCClient( + self.socket_path, lock_type=RequestedLockType.RO, timeout_ms=1000 + ) + try: + ok = ro.committed + finally: + ro.close() + except TimeoutError: + ok = False + + if ok: + self._committed = True + self.close() + logger.info("Committed weights and released RW connection") + return True + + return False + + def allocate(self, size: int, tag: str = "default") -> Tuple[str, int]: + """Returns (allocation_id, aligned_size).""" + r = self._call(AllocateRequest(size=size, tag=tag), AllocateResponse) + return r.allocation_id, r.aligned_size + + def export(self, allocation_id: str) -> int: + """Export allocation as POSIX FD. 
Caller must close.""" + _, fd = self._send_recv(ExportRequest(allocation_id=allocation_id)) + if fd < 0: + raise RuntimeError("No FD received from server") + return fd + + def get_allocation(self, allocation_id: str) -> GetAllocationResponse: + return self._call( + GetAllocationRequest(allocation_id=allocation_id), GetAllocationResponse + ) + + def list_allocations(self, tag: Optional[str] = None) -> List[Dict]: + return self._call( + ListAllocationsRequest(tag=tag), ListAllocationsResponse + ).allocations + + def free(self, allocation_id: str) -> bool: + return self._call( + FreeRequest(allocation_id=allocation_id), FreeResponse + ).success + + def clear_all(self) -> int: + return self._call(ClearAllRequest(), ClearAllResponse).cleared_count + + def metadata_put( + self, key: str, allocation_id: str, offset_bytes: int, value: bytes + ) -> bool: + req = MetadataPutRequest( + key=key, allocation_id=allocation_id, offset_bytes=offset_bytes, value=value + ) + return self._call(req, MetadataPutResponse).success + + def metadata_get(self, key: str) -> Optional[tuple[str, int, bytes]]: + """Returns (allocation_id, offset_bytes, value) or None if not found.""" + r = self._call(MetadataGetRequest(key=key), MetadataGetResponse) + return (r.allocation_id, r.offset_bytes, r.value) if r.found else None + + def metadata_delete(self, key: str) -> bool: + return self._call( + MetadataDeleteRequest(key=key), MetadataDeleteResponse + ).deleted + + def metadata_list(self, prefix: str = "") -> List[str]: + return self._call(MetadataListRequest(prefix=prefix), MetadataListResponse).keys + + def get_memory_layout_hash(self) -> str: + """Get state hash (hash of allocations + metadata). 
Empty if not committed.""" + return self._call( + GetStateHashRequest(), GetStateHashResponse + ).memory_layout_hash + + def close(self) -> None: + """Close connection and release lock.""" + if self._socket: + try: + self._socket.close() + except Exception: + pass + self._socket = None + lock_str = self.lock_type.value if self.lock_type else "unknown" + logger.info(f"Closed {lock_str} connection") + + def __enter__(self) -> "GMSRPCClient": + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit.""" + self.close() + + def __del__(self): + """Destructor: warn if connection not closed.""" + if self._socket: + logger.warning("GMSRPCClient not closed properly") diff --git a/lib/gpu_memory_service/client/torch/__init__.py b/lib/gpu_memory_service/client/torch/__init__.py new file mode 100644 index 00000000000..a7af2f8cc9b --- /dev/null +++ b/lib/gpu_memory_service/client/torch/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""PyTorch integration for GPU Memory Service. 
+ +This module provides PyTorch-specific functionality: + +- Memory manager singleton management +- Tensor utilities (metadata, registration, materialization) +- C++ extension for CUDAPluggableAllocator +""" + +from gpu_memory_service.client.torch.allocator import ( + get_gms_client_memory_manager, + get_or_create_gms_client_memory_manager, +) +from gpu_memory_service.client.torch.module import ( + materialize_module_from_gms, + register_module_tensors, +) + +__all__ = [ + # GMS client memory manager + "get_or_create_gms_client_memory_manager", + "get_gms_client_memory_manager", + # Tensor operations (public API) + "register_module_tensors", + "materialize_module_from_gms", +] diff --git a/lib/gpu_memory_service/client/torch/allocator.py b/lib/gpu_memory_service/client/torch/allocator.py new file mode 100644 index 00000000000..8ecbb244c58 --- /dev/null +++ b/lib/gpu_memory_service/client/torch/allocator.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service allocator singleton management. + +Manages the singleton memory manager and PyTorch MemPool integration. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Optional, Tuple + +from gpu_memory_service.common.types import GrantedLockType, RequestedLockType + +if TYPE_CHECKING: + from gpu_memory_service.client.memory_manager import GMSClientMemoryManager + from torch.cuda.memory import MemPool + +logger = logging.getLogger(__name__) + +# Global singleton state +_gms_client_memory_manager: Optional["GMSClientMemoryManager"] = None +_mem_pool: Optional["MemPool"] = None +_pluggable_alloc: Optional[Any] = None + + +def get_or_create_gms_client_memory_manager( + socket_path: str, + device: int, + mode: RequestedLockType, + *, + tag: str = "weights", + timeout_ms: Optional[int] = None, +) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]: + """Get existing memory manager or create a new one. + + Args: + socket_path: Unix socket path for the allocation server. + device: CUDA device index. + mode: RW for cold start, RO for import-only, RW_OR_RO for auto. + tag: Allocation tag for RW mode. + timeout_ms: Lock acquisition timeout (None = wait indefinitely). + + Returns: + (gms_client_memory_manager, pool) - pool is None for RO mode. 
+ """ + global _gms_client_memory_manager, _mem_pool + + from gpu_memory_service.client.memory_manager import GMSClientMemoryManager + + if _gms_client_memory_manager is not None: + return _get_existing(mode) + + # Create new manager + gms_client_memory_manager = GMSClientMemoryManager( + socket_path, mode=mode, device=device, timeout_ms=timeout_ms + ) + _gms_client_memory_manager = gms_client_memory_manager + + if gms_client_memory_manager.mode == GrantedLockType.RW: + _mem_pool = _setup_mempool(gms_client_memory_manager, tag) + logger.info("[GMS] Created RW allocator (device=%d)", device) + return gms_client_memory_manager, _mem_pool + else: + logger.info("[GMS] Created RO allocator (device=%d)", device) + return gms_client_memory_manager, None + + +def _get_existing( + mode: RequestedLockType, +) -> Tuple["GMSClientMemoryManager", Optional["MemPool"]]: + """Return existing allocator if mode-compatible.""" + current = _gms_client_memory_manager.mode + + if mode == RequestedLockType.RW: + if current == GrantedLockType.RW: + return _gms_client_memory_manager, _mem_pool + raise RuntimeError(f"Cannot get RW allocator: existing is in {current} mode") + + if mode == RequestedLockType.RO: + if current == GrantedLockType.RO: + return _gms_client_memory_manager, None + raise RuntimeError( + f"Cannot get RO allocator: existing is in {current} mode. " + "Call manager.switch_to_read() first." 
+ ) + + # RW_OR_RO: return whatever exists + pool = _mem_pool if current == GrantedLockType.RW else None + return _gms_client_memory_manager, pool + + +def _setup_mempool( + gms_client_memory_manager: "GMSClientMemoryManager", + tag: str, +) -> "MemPool": + """Set up PyTorch CUDAPluggableAllocator and MemPool.""" + global _pluggable_alloc + + from gpu_memory_service.client.torch.extensions import _allocator_ext as cumem + from torch.cuda import CUDAPluggableAllocator + from torch.cuda.memory import MemPool + + pluggable_alloc = CUDAPluggableAllocator(cumem.__file__, "my_malloc", "my_free") + pool = MemPool(allocator=pluggable_alloc.allocator()) + _pluggable_alloc = pluggable_alloc + + def malloc_cb(size: int, device: int, stream: int) -> int: + va = gms_client_memory_manager.allocate_and_map(int(size), tag=tag) + logger.debug("[GMS] malloc: va=0x%x size=%d", va, size) + return va + + def free_cb(ptr: int, size: int, device: int, stream: int) -> None: + logger.debug("[GMS] free: va=0x%x size=%d", ptr, size) + gms_client_memory_manager.free_mapping(int(ptr)) + + cumem.init_module(malloc_cb, free_cb) + return pool + + +def get_gms_client_memory_manager() -> Optional["GMSClientMemoryManager"]: + """Get the active GMS client memory manager, or None if not initialized.""" + return _gms_client_memory_manager diff --git a/lib/gpu_memory_service/client/torch/extensions/__init__.py b/lib/gpu_memory_service/client/torch/extensions/__init__.py new file mode 100644 index 00000000000..a3421a2d3bc --- /dev/null +++ b/lib/gpu_memory_service/client/torch/extensions/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service C++ extensions for PyTorch integration. + +These extensions are built at install time using setuptools. 
+ +- _allocator_ext: CUDAPluggableAllocator backend (my_malloc/my_free) +""" + +# Built by setup.py build_ext --inplace +# Import will fail until extensions are built +try: + from gpu_memory_service.client.torch.extensions import _allocator_ext # noqa: F401 + from gpu_memory_service.client.torch.extensions._allocator_ext import * # noqa: F401, F403 +except ImportError: + _allocator_ext = None # type: ignore diff --git a/lib/gpu_memory_service/client/torch/extensions/allocator.cpp b/lib/gpu_memory_service/client/torch/extensions/allocator.cpp new file mode 100644 index 00000000000..4e273fbd7bc --- /dev/null +++ b/lib/gpu_memory_service/client/torch/extensions/allocator.cpp @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Minimal CUDAPluggableAllocator shim for GPU Memory Service. +// +// This extension provides the my_malloc/my_free function pointers required by +// PyTorch's CUDAPluggableAllocator. All actual CUDA VMM operations are delegated +// to Python callbacks which use cuda.bindings. +// +// Note: The stream parameter is unused because CUDA VMM operations (cuMemMap, +// cuMemUnmap) are synchronous and globally visible - they don't have per-stream +// semantics like cudaMallocAsync. We keep the parameter to match PyTorch's +// CUDAPluggableAllocator interface signature. 

#define PY_SSIZE_T_CLEAN
#include <Python.h>

#include <sys/types.h>  // ssize_t (presumably the original's second, garbled include — confirm)

static PyObject* g_malloc_callback = nullptr;
static PyObject* g_free_callback = nullptr;

extern "C" {

// Allocate `size` bytes on `device` by delegating to the registered Python
// malloc callback. Returns nullptr if no callback is installed or the
// callback fails. `stream` is forwarded but unused (VMM mapping is not
// per-stream; see file header).
void*
my_malloc(ssize_t size, int device, void* stream)
{
    if (!g_malloc_callback) {
        return nullptr;
    }

    PyGILState_STATE gstate = PyGILState_Ensure();

    void* ptr = nullptr;
    PyObject* args = Py_BuildValue("(niK)", size, device, (unsigned long long)stream);
    // Guard against Py_BuildValue failure: PyObject_CallObject(cb, nullptr)
    // would otherwise invoke the callback with ZERO arguments.
    if (args) {
        PyObject* result = PyObject_CallObject(g_malloc_callback, args);
        Py_DECREF(args);
        if (result && PyLong_Check(result)) {
            ptr = (void*)PyLong_AsUnsignedLongLong(result);
        }
        Py_XDECREF(result);
    }

    if (PyErr_Occurred()) {
        PyErr_Print();
    }

    PyGILState_Release(gstate);
    return ptr;
}

// Release `ptr` by delegating to the registered Python free callback.
// Errors from the callback are printed, never propagated (PyTorch's
// allocator interface cannot handle exceptions here).
void
my_free(void* ptr, ssize_t size, int device, void* stream)
{
    if (!g_free_callback) {
        return;
    }

    PyGILState_STATE gstate = PyGILState_Ensure();

    PyObject* args = Py_BuildValue("(KniK)", (unsigned long long)ptr, size, device, (unsigned long long)stream);
    if (args) {  // same nullptr guard as in my_malloc
        PyObject* result = PyObject_CallObject(g_free_callback, args);
        Py_DECREF(args);
        Py_XDECREF(result);
    }

    if (PyErr_Occurred()) {
        PyErr_Print();
    }

    PyGILState_Release(gstate);
}

// init_module(malloc_cb, free_cb): install the Python callbacks used by
// my_malloc/my_free. Replaces (and releases) any previously installed pair.
static PyObject*
py_init_module(PyObject* self, PyObject* args)
{
    PyObject* malloc_cb = nullptr;
    PyObject* free_cb = nullptr;

    if (!PyArg_ParseTuple(args, "OO", &malloc_cb, &free_cb)) {
        return nullptr;
    }

    if (!PyCallable_Check(malloc_cb) || !PyCallable_Check(free_cb)) {
        PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
        return nullptr;
    }

    // INCREF the new callbacks before DECREFing the old ones so
    // re-registering the same objects cannot drop the last reference.
    Py_XINCREF(malloc_cb);
    Py_XINCREF(free_cb);
    Py_XDECREF(g_malloc_callback);
    Py_XDECREF(g_free_callback);

    g_malloc_callback = malloc_cb;
    g_free_callback = free_cb;

    Py_RETURN_NONE;
}

static PyMethodDef module_methods[] = {
    {"init_module", py_init_module, METH_VARARGS, "Set malloc/free callbacks"}, {nullptr, nullptr, 0, nullptr}};

static struct PyModuleDef allocator_module = {
PyModuleDef_HEAD_INIT, "_allocator_ext", "CUDAPluggableAllocator shim for GPU Memory Service", -1, module_methods}; + +PyMODINIT_FUNC +PyInit__allocator_ext(void) +{ + return PyModule_Create(&allocator_module); +} + +} // extern "C" diff --git a/lib/gpu_memory_service/client/torch/module.py b/lib/gpu_memory_service/client/torch/module.py new file mode 100644 index 00000000000..ca7e2d0baa3 --- /dev/null +++ b/lib/gpu_memory_service/client/torch/module.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Module tensor operations for GPU Memory Service. + +This module provides module-level tensor operations: +- Module tensor iteration +- Tensor registration (write path) +- Tensor materialization (read path) +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Iterator, Tuple + +import torch +from gpu_memory_service.client.torch.tensor import GMSTensorSpec, TensorMetadata + +if TYPE_CHECKING: + from gpu_memory_service.client.memory_manager import GMSClientMemoryManager + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Module Tensor Iteration +# ============================================================================= + + +def _iter_module_tensors( + module: torch.nn.Module, + prefix: str = "", +) -> Iterator[Tuple[str, torch.Tensor, str]]: + """Iterate over all CUDA tensors in a module tree. + + Yields (qualified_name, tensor, tensor_type) for: + - Parameters (tensor_type="parameter") + - Buffers (tensor_type="buffer") + - Other tensor attributes like _k_scale (tensor_type="tensor_attr") + + Args: + module: The nn.Module to iterate. + prefix: Prefix for qualified names (used in recursion). + + Yields: + (name, tensor, tensor_type) tuples for each CUDA tensor. 
+ """ + # Parameters + for name, param in module._parameters.items(): + if param is not None and param.is_cuda: + qualified = f"{prefix}{name}" if prefix else name + yield (qualified, param, "parameter") + + # Buffers + for name, buf in module._buffers.items(): + if buf is not None and buf.is_cuda: + qualified = f"{prefix}{name}" if prefix else name + yield (qualified, buf, "buffer") + + # Other tensor attributes (not params/buffers/submodules) + skip = ( + set(module._parameters.keys()) + | set(module._buffers.keys()) + | set(module._modules.keys()) + ) + for attr_name in dir(module): + if attr_name in skip or attr_name.startswith("__"): + continue + try: + attr_val = getattr(module, attr_name, None) + except Exception: + continue + + if torch.is_tensor(attr_val) and attr_val.is_cuda: + qualified = f"{prefix}{attr_name}" if prefix else attr_name + yield (qualified, attr_val, "tensor_attr") + elif isinstance(attr_val, (list, tuple)) and attr_val: + if all(torch.is_tensor(x) and x.is_cuda for x in attr_val): + for i, x in enumerate(attr_val): + qualified = ( + f"{prefix}{attr_name}.{i}" if prefix else f"{attr_name}.{i}" + ) + yield (qualified, x, "tensor_attr") + + # Recurse into submodules + for name, submodule in module._modules.items(): + if submodule is not None: + subprefix = f"{prefix}{name}." if prefix else f"{name}." + yield from _iter_module_tensors(submodule, subprefix) + + +def _resolve_module_attr( + root: torch.nn.Module, qualified_name: str +) -> Tuple[torch.nn.Module, str]: + """Resolve a dotted name to (parent_module, leaf_attr). + + Handles ModuleList/Sequential (numeric indices) and ModuleDict (key access). 
+ """ + parts = qualified_name.split(".") + mod = root + for p in parts[:-1]: + if hasattr(mod, p): + mod = getattr(mod, p) + elif hasattr(mod, "__getitem__"): + try: + mod = mod[int(p)] if p.isdigit() else mod[p] + except Exception: + raise AttributeError(f"Cannot resolve {p!r} in {qualified_name!r}") + else: + raise AttributeError(f"Cannot resolve {p!r} in {qualified_name!r}") + return mod, parts[-1] + + +# ============================================================================= +# Public API - Registration and Materialization +# ============================================================================= + + +def register_module_tensors( + gms_client_memory_manager: "GMSClientMemoryManager", + model: torch.nn.Module, +) -> None: + """Register all model tensors into the GMS metadata store. + + Args: + gms_client_memory_manager: GMS client memory manager in write mode. + model: PyTorch model to register. + """ + for name, tensor, tensor_type in _iter_module_tensors(model): + ptr = int(tensor.data_ptr()) + + # Find allocation containing this tensor + for va, mapping in gms_client_memory_manager.mappings.items(): + if va <= ptr < va + mapping.aligned_size: + offset = ptr - va + meta = TensorMetadata.from_tensor(tensor, tensor_type) + gms_client_memory_manager.metadata_put( + key=name, + allocation_id=mapping.allocation_id, + offset_bytes=offset, + value=meta.to_bytes(), + ) + break + else: + # No mapping matched - tensor pointer not in any GMS allocation + if tensor_type == "parameter": + # Parameters are model weights - must be in GMS allocations + raise RuntimeError(f"Tensor {name!r} not found in any GMS allocation") + # Buffers and tensor_attrs may be dynamically allocated (e.g., KV cache) + logger.debug( + "[GMS] Skipping %s %r - not in GMS allocations", tensor_type, name + ) + + +def materialize_module_from_gms( + gms_client_memory_manager: "GMSClientMemoryManager", + model: torch.nn.Module, + *, + device_index: int, +) -> None: + """Materialize model 
tensors from GMS. + + Args: + gms_client_memory_manager: GMS client memory manager in read mode. + model: Model to populate with tensors. + device_index: CUDA device index. + """ + specs = GMSTensorSpec.load_all(gms_client_memory_manager) + + for name, spec in specs.items(): + tensor = spec.materialize(gms_client_memory_manager, device_index) + mod, attr = _resolve_module_attr(model, name) + tensor_type = spec.meta.tensor_type + + # Tensor attrs and buffers: clone since they may be mutated + if tensor_type in ("tensor_attr", "buffer"): + if ( + tensor_type == "buffer" + and hasattr(mod, "_buffers") + and attr in mod._buffers + ): + mod._buffers[attr] = tensor.detach().clone() + else: + setattr(mod, attr, tensor.detach().clone()) + continue + + # Parameters: in-place update or replace meta tensors + if hasattr(mod, "_parameters") and attr in mod._parameters: + param = mod._parameters[attr] + if param is not None: + if param.shape != tensor.shape or param.dtype != tensor.dtype: + raise RuntimeError( + f"Shape/dtype mismatch for {name}: " + f"param={tuple(param.shape)}/{param.dtype}, " + f"gms={tuple(tensor.shape)}/{tensor.dtype}" + ) + if param.is_meta or param.device != tensor.device: + mod._parameters[attr] = torch.nn.Parameter( + tensor, requires_grad=param.requires_grad + ) + else: + param.data = tensor + continue + + # Fallback: set as attribute + setattr(mod, attr, tensor) + + # Check for meta tensors and warn + meta_tensors = [n for n, p in model.named_parameters() if p.is_meta] + meta_tensors += [n for n, b in model.named_buffers() if b.is_meta] + if meta_tensors: + logger.warning( + "[GMS] %d meta tensors not in metadata: %s", + len(meta_tensors), + meta_tensors[:10], + ) diff --git a/lib/gpu_memory_service/client/torch/tensor.py b/lib/gpu_memory_service/client/torch/tensor.py new file mode 100644 index 00000000000..a00e9abf965 --- /dev/null +++ b/lib/gpu_memory_service/client/torch/tensor.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tensor utilities for GPU Memory Service. + +This module provides low-level tensor functionality: +- Tensor creation from CUDA pointers +- Tensor metadata serialization/deserialization +- GMS tensor spec for metadata store entries +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Tuple + +import torch + +if TYPE_CHECKING: + from gpu_memory_service.client.memory_manager import GMSClientMemoryManager + + +# ============================================================================= +# Tensor Creation from CUDA Pointer +# ============================================================================= + + +def _tensor_from_pointer( + data_ptr: int, + shape: List[int], + stride: List[int], + dtype: torch.dtype, + device_index: int, +) -> torch.Tensor: + """Create a torch.Tensor from a raw CUDA pointer without copying data. + + Uses PyTorch's internal APIs to create a tensor that aliases existing + GPU memory. The tensor does NOT own the memory - the caller must ensure + the memory remains valid for the tensor's lifetime. + + Args: + data_ptr: CUDA device pointer (virtual address) to the tensor data. + shape: Tensor dimensions. + stride: Tensor strides (in elements, not bytes). + dtype: Tensor data type. + device_index: CUDA device index where the memory resides. + + Returns: + A tensor aliasing the specified GPU memory. 
+ """ + device = torch.device("cuda", device_index) + + # Calculate storage size in bytes based on stride (handles non-contiguous tensors) + # For non-contiguous tensors, the memory footprint is larger than numel * element_size + element_size = torch.tensor([], dtype=dtype).element_size() + + if shape and stride: + if len(shape) != len(stride): + raise ValueError( + f"Shape and stride length mismatch: {len(shape)} vs {len(stride)}" + ) + # Maximum offset = sum of stride[i] * (shape[i] - 1) for all dimensions + max_offset = sum( + s * (d - 1) for s, d in zip(stride, shape, strict=True) if d > 0 + ) + required_elements = max_offset + 1 + else: + # Scalar tensor or empty tensor + required_elements = 1 + + storage_size_bytes = required_elements * element_size + + # Create storage from raw pointer (does not take ownership) + storage = torch._C._construct_storage_from_data_pointer( + data_ptr, device, storage_size_bytes + ) + + # Create tensor from storage with metadata + metadata = { + "size": torch.Size(shape), + "stride": stride, + "storage_offset": 0, + "dtype": dtype, + } + + return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, storage) + + +# ============================================================================= +# Tensor Metadata - serialization format for metadata store +# ============================================================================= + + +def _parse_dtype(dtype_str: str) -> torch.dtype: + """Parse dtype string (e.g., 'torch.float16') to torch.dtype.""" + s = str(dtype_str) + if s.startswith("torch."): + s = s.split(".", 1)[1] + dt = getattr(torch, s, None) + if not isinstance(dt, torch.dtype): + raise ValueError(f"Unknown dtype: {dtype_str!r}") + return dt + + +@dataclass(frozen=True) +class TensorMetadata: + """Metadata for a tensor stored in the GMS metadata store.""" + + shape: Tuple[int, ...] + dtype: torch.dtype + stride: Tuple[int, ...] 
+ tensor_type: str = "parameter" # "parameter", "buffer", or "tensor_attr" + + @classmethod + def from_tensor( + cls, tensor: torch.Tensor, tensor_type: str = "parameter" + ) -> "TensorMetadata": + """Create TensorMetadata from an existing tensor.""" + return cls( + shape=tuple(tensor.shape), + dtype=tensor.dtype, + stride=tuple(int(s) for s in tensor.stride()), + tensor_type=tensor_type, + ) + + @classmethod + def from_bytes(cls, value: bytes) -> "TensorMetadata": + """Parse metadata from JSON bytes.""" + obj = json.loads(value.decode("utf-8")) + shape = tuple(int(x) for x in obj["shape"]) + dtype = _parse_dtype(obj["dtype"]) + + if "stride" in obj and obj["stride"] is not None: + stride = tuple(int(x) for x in obj["stride"]) + else: + # Legacy format: compute contiguous stride + stride = [] + acc = 1 + for d in reversed(shape): + stride.append(acc) + acc *= d + stride = tuple(reversed(stride)) if stride else () + + return cls( + shape=shape, + dtype=dtype, + stride=stride, + tensor_type=obj.get("tensor_type", "parameter"), + ) + + def to_bytes(self) -> bytes: + """Serialize to JSON bytes for metadata store.""" + return json.dumps( + { + "shape": list(self.shape), + "dtype": str(self.dtype), + "stride": list(self.stride), + "tensor_type": self.tensor_type, + }, + sort_keys=True, + ).encode("utf-8") + + +# ============================================================================= +# GMS Tensor Spec - metadata entry from store +# ============================================================================= + + +@dataclass(frozen=True) +class GMSTensorSpec: + """A tensor entry from the GMS metadata store.""" + + key: str + name: str + allocation_id: str + offset_bytes: int + meta: TensorMetadata + + @classmethod + def load_all( + cls, gms_client_memory_manager: "GMSClientMemoryManager" + ) -> Dict[str, "GMSTensorSpec"]: + """Load all metadata entries. + + Returns: + Mapping of tensor name -> GMSTensorSpec. 
+ """ + specs: Dict[str, GMSTensorSpec] = {} + + for key in gms_client_memory_manager.metadata_list(): + got = gms_client_memory_manager.metadata_get(key) + if got is None: + raise RuntimeError(f"Metadata key disappeared: {key}") + + allocation_id, offset_bytes, value = got + + if key in specs: + raise RuntimeError(f"Duplicate tensor name: {key}") + + specs[key] = cls( + key=key, + name=key, + allocation_id=str(allocation_id), + offset_bytes=int(offset_bytes), + meta=TensorMetadata.from_bytes(value), + ) + + return specs + + def materialize( + self, + gms_client_memory_manager: "GMSClientMemoryManager", + device_index: int, + ) -> torch.Tensor: + """Create a tensor aliasing mapped CUDA memory.""" + base_va = gms_client_memory_manager.import_allocation(self.allocation_id) + ptr = int(base_va) + int(self.offset_bytes) + + return _tensor_from_pointer( + ptr, + list(self.meta.shape), + list(self.meta.stride), + self.meta.dtype, + device_index, + ) diff --git a/lib/gpu_memory_service/common/__init__.py b/lib/gpu_memory_service/common/__init__.py new file mode 100644 index 00000000000..52a7a9daf02 --- /dev/null +++ b/lib/gpu_memory_service/common/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/gpu_memory_service/common/cuda_vmm_utils.py b/lib/gpu_memory_service/common/cuda_vmm_utils.py new file mode 100644 index 00000000000..7d38b181d8f --- /dev/null +++ b/lib/gpu_memory_service/common/cuda_vmm_utils.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""CUDA Virtual Memory Management (VMM) utility functions. + +This module provides utility functions for CUDA driver API operations +used by both server (GMSServerMemoryManager) and client (GMSClientMemoryManager). 
+"""
+
+from cuda.bindings import driver as cuda
+
+
+def check_cuda_result(result: cuda.CUresult, name: str) -> None:
+    """Check CUDA driver API result and raise on error.
+
+    Args:
+        result: CUDA driver API return code (CUresult enum)
+        name: Operation name for error message
+
+    Raises:
+        RuntimeError: If result is not CUDA_SUCCESS
+    """
+    if result != cuda.CUresult.CUDA_SUCCESS:
+        # cuGetErrorString itself returns a CUresult; only trust the string
+        # it produced when that call succeeded, otherwise fall back to the
+        # raw enum repr.
+        err_result, err_str = cuda.cuGetErrorString(result)
+        if err_result == cuda.CUresult.CUDA_SUCCESS and err_str:
+            err_msg = err_str.decode() if isinstance(err_str, bytes) else str(err_str)
+        else:
+            err_msg = str(result)
+        raise RuntimeError(f"{name}: {err_msg}")
+
+
+def ensure_cuda_initialized() -> None:
+    """Ensure CUDA driver is initialized.
+
+    Raises:
+        RuntimeError: If cuInit fails
+    """
+    # cuInit is documented as idempotent in the CUDA driver API, so callers
+    # may invoke this freely before any driver-API operation.
+    (result,) = cuda.cuInit(0)
+    check_cuda_result(result, "cuInit")
+
+
+def get_allocation_granularity(device: int) -> int:
+    """Get VMM allocation granularity for a device.
+
+    Args:
+        device: CUDA device index
+
+    Returns:
+        Allocation granularity in bytes (typically 2 MiB)
+    """
+    prop = cuda.CUmemAllocationProp()
+    prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
+    prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
+    prop.location.id = device
+    # NOTE(review): POSIX FD handles are requested here, presumably to match
+    # the SCM_RIGHTS fd-passing path in the wire protocol -- confirm this
+    # matches the handle type used by the server's export path.
+    prop.requestedHandleTypes = (
+        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+    )
+
+    result, granularity = cuda.cuMemGetAllocationGranularity(
+        prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
+    )
+    check_cuda_result(result, "cuMemGetAllocationGranularity")
+    return int(granularity)
+
+
+def align_to_granularity(size: int, granularity: int) -> int:
+    """Align size up to VMM granularity. 
+ + Args: + size: Size in bytes + granularity: Allocation granularity + + Returns: + Aligned size + """ + return ((size + granularity - 1) // granularity) * granularity diff --git a/lib/gpu_memory_service/common/protocol/__init__.py b/lib/gpu_memory_service/common/protocol/__init__.py new file mode 100644 index 00000000000..52a7a9daf02 --- /dev/null +++ b/lib/gpu_memory_service/common/protocol/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/gpu_memory_service/common/protocol/messages.py b/lib/gpu_memory_service/common/protocol/messages.py new file mode 100644 index 00000000000..fb8bf02b05b --- /dev/null +++ b/lib/gpu_memory_service/common/protocol/messages.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Message types for GPU Memory Service RPC protocol.""" + +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +import msgspec + + +class RequestedLockType(str, Enum): + """Lock type requested by client.""" + + RW = "rw" + RO = "ro" + RW_OR_RO = "rw_or_ro" + + +class GrantedLockType(str, Enum): + """Lock type actually granted by server.""" + + RW = "rw" + RO = "ro" + + +class HandshakeRequest(msgspec.Struct, tag="handshake_request"): + lock_type: RequestedLockType + timeout_ms: Optional[int] = None + + +class HandshakeResponse(msgspec.Struct, tag="handshake_response"): + success: bool + committed: bool + granted_lock_type: Optional[GrantedLockType] = None + + +class CommitRequest(msgspec.Struct, tag="commit_request"): + pass + + +class CommitResponse(msgspec.Struct, tag="commit_response"): + success: bool + + +class GetLockStateRequest(msgspec.Struct, tag="get_lock_state_request"): + pass + + +class GetLockStateResponse(msgspec.Struct, tag="get_lock_state_response"): + state: str # "EMPTY", 
"RW", "COMMITTED", "RO" + has_rw_session: bool + ro_session_count: int + waiting_writers: int + committed: bool + is_ready: bool + + +class GetAllocationStateRequest(msgspec.Struct, tag="get_allocation_state_request"): + pass + + +class GetAllocationStateResponse(msgspec.Struct, tag="get_allocation_state_response"): + allocation_count: int + total_bytes: int + + +class AllocateRequest(msgspec.Struct, tag="allocate_request"): + size: int + tag: str = "default" + + +class AllocateResponse(msgspec.Struct, tag="allocate_response"): + allocation_id: str + size: int + aligned_size: int + + +class ExportRequest(msgspec.Struct, tag="export_request"): + allocation_id: str + + +class GetAllocationRequest(msgspec.Struct, tag="get_allocation_request"): + allocation_id: str + + +class GetAllocationResponse(msgspec.Struct, tag="get_allocation_response"): + allocation_id: str + size: int + aligned_size: int + tag: str + + +class ListAllocationsRequest(msgspec.Struct, tag="list_allocations_request"): + tag: Optional[str] = None + + +class ListAllocationsResponse(msgspec.Struct, tag="list_allocations_response"): + allocations: List[Dict[str, Any]] = [] + + +class FreeRequest(msgspec.Struct, tag="free_request"): + allocation_id: str + + +class FreeResponse(msgspec.Struct, tag="free_response"): + success: bool + + +class ClearAllRequest(msgspec.Struct, tag="clear_all_request"): + pass + + +class ClearAllResponse(msgspec.Struct, tag="clear_all_response"): + cleared_count: int + + +class ErrorResponse(msgspec.Struct, tag="error_response"): + error: str + code: int = 0 + + +class MetadataPutRequest(msgspec.Struct, tag="metadata_put_request"): + key: str + allocation_id: str + offset_bytes: int + value: bytes + + +class MetadataPutResponse(msgspec.Struct, tag="metadata_put_response"): + success: bool + + +class MetadataGetRequest(msgspec.Struct, tag="metadata_get_request"): + key: str + + +class MetadataGetResponse(msgspec.Struct, tag="metadata_get_response"): + found: bool + 
allocation_id: Optional[str] = None + offset_bytes: Optional[int] = None + value: Optional[bytes] = None + + +class MetadataDeleteRequest(msgspec.Struct, tag="metadata_delete_request"): + key: str + + +class MetadataDeleteResponse(msgspec.Struct, tag="metadata_delete_response"): + deleted: bool + + +class MetadataListRequest(msgspec.Struct, tag="metadata_list_request"): + prefix: str = "" + + +class MetadataListResponse(msgspec.Struct, tag="metadata_list_response"): + keys: List[str] = [] + + +class GetStateHashRequest(msgspec.Struct, tag="get_memory_layout_hash_request"): + pass + + +class GetStateHashResponse(msgspec.Struct, tag="get_memory_layout_hash_response"): + memory_layout_hash: str # Hash of allocations + metadata, empty if not committed + + +Message = Union[ + HandshakeRequest, + HandshakeResponse, + CommitRequest, + CommitResponse, + GetLockStateRequest, + GetLockStateResponse, + GetAllocationStateRequest, + GetAllocationStateResponse, + AllocateRequest, + AllocateResponse, + ExportRequest, + GetAllocationRequest, + GetAllocationResponse, + ListAllocationsRequest, + ListAllocationsResponse, + FreeRequest, + FreeResponse, + ClearAllRequest, + ClearAllResponse, + ErrorResponse, + MetadataPutRequest, + MetadataPutResponse, + MetadataGetRequest, + MetadataGetResponse, + MetadataDeleteRequest, + MetadataDeleteResponse, + MetadataListRequest, + MetadataListResponse, + GetStateHashRequest, + GetStateHashResponse, +] + +_encoder = msgspec.msgpack.Encoder() +_decoder = msgspec.msgpack.Decoder(Message) + + +def encode_message(msg: Message) -> bytes: + return _encoder.encode(msg) + + +def decode_message(data: bytes) -> Message: + return _decoder.decode(data) diff --git a/lib/gpu_memory_service/common/protocol/wire.py b/lib/gpu_memory_service/common/protocol/wire.py new file mode 100644 index 00000000000..0d3b25970fb --- /dev/null +++ b/lib/gpu_memory_service/common/protocol/wire.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION 
& AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Wire protocol for length-prefixed messages with optional FD passing.""" + +import asyncio +import os +import socket +import struct +from typing import Optional, Tuple + +from .messages import Message, decode_message, encode_message + +HEADER_SIZE = 4 # 4-byte big-endian length prefix + + +def _frame_message(msg: Message) -> bytes: + """Encode and frame a message with length prefix.""" + data = encode_message(msg) + return struct.pack("!I", len(data)) + data + + +def _try_extract_message( + recv_buffer: bytearray, +) -> Tuple[Optional[Message], bytearray, int]: + """Try to extract a complete message from buffer. + + Returns (message, remaining_buffer, bytes_needed). + """ + if len(recv_buffer) < HEADER_SIZE: + return None, recv_buffer, HEADER_SIZE - len(recv_buffer) + + length = struct.unpack("!I", bytes(recv_buffer[:HEADER_SIZE]))[0] + total_needed = HEADER_SIZE + length + + if len(recv_buffer) < total_needed: + return None, recv_buffer, total_needed - len(recv_buffer) + + msg_data = bytes(recv_buffer[HEADER_SIZE:total_needed]) + remaining = bytearray(recv_buffer[total_needed:]) + return decode_message(msg_data), remaining, 0 + + +# ==================== Async (for server) ==================== + + +async def send_message(writer, msg: Message, fd: int = -1) -> None: + """Send a length-prefixed message with optional FD via SCM_RIGHTS.""" + frame = _frame_message(msg) + + if fd >= 0: + transport_sock = writer.get_extra_info("socket") + if transport_sock is None: + raise RuntimeError("Cannot get socket from transport for FD passing") + + def do_send_fd(): + raw_fd = transport_sock.fileno() + dup_fd = os.dup(raw_fd) + try: + sock = socket.socket(fileno=dup_fd) + try: + sock.setblocking(True) + socket.send_fds(sock, [frame], [fd]) + finally: + sock.detach() + except Exception: + os.close(dup_fd) + raise + + await asyncio.get_running_loop().run_in_executor(None, do_send_fd) + else: + 
writer.write(frame) + await writer.drain() + + +async def recv_message( + reader, recv_buffer: Optional[bytearray] = None, raw_sock=None +) -> Tuple[Optional[Message], int, bytearray]: + """Receive a length-prefixed message with optional FD. + + Returns (message, fd, remaining_buffer). fd is -1 if none sent. + """ + if recv_buffer is None: + recv_buffer = bytearray() + + # Check if complete message already in buffer + msg, remaining, _ = _try_extract_message(recv_buffer) + if msg is not None: + return msg, -1, remaining + + loop = asyncio.get_running_loop() + fd = -1 + + # Receive more data + if raw_sock is not None: + raw_msg, fds, _flags, _addr = await loop.run_in_executor( + None, lambda: socket.recv_fds(raw_sock, 65536, 1) + ) + if not raw_msg: + raise ConnectionResetError("Connection closed") + recv_buffer.extend(raw_msg) + fd = fds[0] if fds else -1 + else: + chunk = await reader.read(65536) + if not chunk: + raise ConnectionResetError("Connection closed") + recv_buffer.extend(chunk) + + # Try to extract message, read more if needed + msg, remaining, bytes_needed = _try_extract_message(recv_buffer) + while msg is None and bytes_needed > 0: + if raw_sock is not None: + # Continue reading from raw socket to avoid buffer inconsistency + chunk = await loop.run_in_executor( + None, lambda n=bytes_needed: raw_sock.recv(n) + ) + else: + chunk = await reader.read(bytes_needed) + if not chunk: + raise ConnectionResetError("Connection closed") + remaining.extend(chunk) + msg, remaining, bytes_needed = _try_extract_message(remaining) + + return msg, fd, remaining + + +# ==================== Sync (for client) ==================== + + +def send_message_sync(sock, msg: Message, fd: int = -1) -> None: + """Send a length-prefixed message with optional FD via SCM_RIGHTS.""" + frame = _frame_message(msg) + if fd >= 0: + socket.send_fds(sock, [frame], [fd]) + else: + sock.sendall(frame) + + +def recv_message_sync( + sock, recv_buffer: Optional[bytearray] = None +) -> 
Tuple[Optional[Message], int, bytearray]: + """Receive a length-prefixed message with optional FD. + + Returns (message, fd, remaining_buffer). fd is -1 if none sent. + """ + if recv_buffer is None: + recv_buffer = bytearray() + + # Check if complete message already in buffer + msg, remaining, _ = _try_extract_message(recv_buffer) + if msg is not None: + return msg, -1, remaining + + # Receive more data (with potential FD) + raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1) + if not raw_msg: + raise ConnectionResetError("Connection closed") + recv_buffer.extend(raw_msg) + fd = fds[0] if fds else -1 + + # Try to extract message, read more if needed + msg, remaining, bytes_needed = _try_extract_message(recv_buffer) + while msg is None and bytes_needed > 0: + chunk = sock.recv(bytes_needed) + if not chunk: + raise ConnectionResetError("Connection closed") + remaining.extend(chunk) + msg, remaining, bytes_needed = _try_extract_message(remaining) + + return msg, fd, remaining diff --git a/lib/gpu_memory_service/common/types.py b/lib/gpu_memory_service/common/types.py new file mode 100644 index 00000000000..3c4afd4069c --- /dev/null +++ b/lib/gpu_memory_service/common/types.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Shared types for GPU Memory Service.""" + +from dataclasses import dataclass +from enum import Enum, auto + +from gpu_memory_service.common.protocol.messages import ( + AllocateRequest, + ClearAllRequest, + CommitRequest, + ExportRequest, + FreeRequest, + GetAllocationRequest, + GetAllocationStateRequest, + GetLockStateRequest, + GetStateHashRequest, + GrantedLockType, + ListAllocationsRequest, + MetadataDeleteRequest, + MetadataGetRequest, + MetadataListRequest, + MetadataPutRequest, + RequestedLockType, +) + +# Re-export lock types for convenience +__all__ = [ + "GrantedLockType", + "RequestedLockType", + "ServerState", + "StateEvent", + "StateSnapshot", + "derive_state", + "RW_REQUIRED", + "RO_ALLOWED", + "RW_ALLOWED", +] + + +class ServerState(str, Enum): + """Server state - derived from actual connections.""" + + EMPTY = "EMPTY" + RW = "RW" + COMMITTED = "COMMITTED" + RO = "RO" + + +class StateEvent(Enum): + """Events that trigger state transitions.""" + + RW_CONNECT = auto() + RW_COMMIT = auto() + RW_ABORT = auto() + RO_CONNECT = auto() + RO_DISCONNECT = auto() + + +@dataclass +class StateSnapshot: + """Current server state snapshot.""" + + state: ServerState + has_rw: bool + ro_count: int + waiting_writers: int + committed: bool + + @property + def is_ready(self) -> bool: + """Ready = committed and no RW connection.""" + return self.committed and not self.has_rw + + +def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState: + """Derive server state from connection info.""" + if has_rw: + return ServerState.RW + if ro_count > 0: + return ServerState.RO + if committed: + return ServerState.COMMITTED + return ServerState.EMPTY + + +# Permission sets: which message types require which connection mode +RW_REQUIRED: frozenset[type] = frozenset( + { + AllocateRequest, + FreeRequest, + ClearAllRequest, + MetadataPutRequest, + MetadataDeleteRequest, + CommitRequest, + } +) + +RO_ALLOWED: 
frozenset[type] = frozenset( + { + ExportRequest, + GetAllocationRequest, + ListAllocationsRequest, + MetadataGetRequest, + MetadataListRequest, + GetLockStateRequest, + GetAllocationStateRequest, + GetStateHashRequest, + } +) + +RW_ALLOWED: frozenset[type] = RW_REQUIRED | RO_ALLOWED diff --git a/lib/gpu_memory_service/pyproject.toml b/lib/gpu_memory_service/pyproject.toml new file mode 100644 index 00000000000..0c412f25962 --- /dev/null +++ b/lib/gpu_memory_service/pyproject.toml @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "gpu-memory-service" +version = "0.8.0" +description = "GPU Memory Service for Dynamo - CUDA VMM-based GPU memory allocation and sharing" +readme = "README.md" +authors = [ + { name = "NVIDIA Inc.", email = "sw-dl-dynamo@nvidia.com" }, +] +license = { text = "Apache-2.0" } +requires-python = ">=3.10" +dependencies = [ + "msgspec>=0.18.0", + "uvloop>=0.21.0", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Information Technology", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Operating System :: POSIX :: Linux", +] +keywords = ["llm", "genai", "inference", "nvidia", "gpu", "memory", "dynamo"] + +[project.optional-dependencies] +test = [ + "pytest>=8.3.4", + "pytest-asyncio", +] diff --git a/lib/gpu_memory_service/server/__init__.py b/lib/gpu_memory_service/server/__init__.py new file mode 100644 index 00000000000..76375989605 --- /dev/null +++ 
b/lib/gpu_memory_service/server/__init__.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPU Memory Service server components.""" + +from gpu_memory_service.common.types import ( + GrantedLockType, + RequestedLockType, + ServerState, + StateSnapshot, +) +from gpu_memory_service.server.handler import MetadataEntry, RequestHandler +from gpu_memory_service.server.locking import Connection, GlobalLockFSM +from gpu_memory_service.server.memory_manager import ( + AllocationInfo, + AllocationNotFoundError, + GMSServerMemoryManager, +) +from gpu_memory_service.server.rpc import GMSRPCServer + +__all__ = [ + "GMSRPCServer", + "GMSServerMemoryManager", + "AllocationInfo", + "AllocationNotFoundError", + "MetadataEntry", + "Connection", + "GrantedLockType", + "RequestedLockType", + "RequestHandler", + "ServerState", + "GlobalLockFSM", + "StateSnapshot", +] diff --git a/lib/gpu_memory_service/server/handler.py b/lib/gpu_memory_service/server/handler.py new file mode 100644 index 00000000000..a21572bb88d --- /dev/null +++ b/lib/gpu_memory_service/server/handler.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import hashlib
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class MetadataEntry:
    """A single metadata record attached to an allocation.

    Attributes:
        allocation_id: Allocation this record points into.
        offset_bytes: Byte offset within that allocation.
        value: Opaque client-supplied payload.
    """

    allocation_id: str
    offset_bytes: int
    value: bytes


class RequestHandler:
    """Handles allocation and metadata requests.

    Pure business logic: RW/RO lock and permission checks are enforced by
    the RPC server before these handlers are invoked.
    """

    def __init__(self, device: int = 0):
        self._memory_manager = GMSServerMemoryManager(device)
        self._metadata: dict[str, MetadataEntry] = {}
        # Hash of allocations + metadata, computed on commit.
        self._memory_layout_hash: str = ""
        logger.info(f"RequestHandler initialized: device={device}")

    @property
    def granularity(self) -> int:
        """VMM allocation granularity of the underlying memory manager."""
        return self._memory_manager.granularity

    def on_rw_abort(self) -> None:
        """Called when RW connection closes without commit.

        Discards all allocations and metadata: an uncommitted writer's
        state must never become observable by readers.
        """
        logger.warning("RW aborted; clearing allocations and metadata")
        self._memory_manager.clear_all()
        self._metadata.clear()
        self._memory_layout_hash = ""

    def on_commit(self) -> None:
        """Called when RW connection commits. Computes state hash."""
        self._memory_layout_hash = self._compute_memory_layout_hash()
        logger.info(f"Committed with state hash: {self._memory_layout_hash[:16]}...")

    def _compute_memory_layout_hash(self) -> str:
        """Compute hash of current allocations + metadata.

        This is a change-detection fingerprint (deterministic over sorted
        entries), not a security digest.
        """
        h = hashlib.sha256()
        # Hash allocations (sorted by ID for determinism)
        for info in sorted(
            self._memory_manager.list_allocations(), key=lambda x: x.allocation_id
        ):
            h.update(
                f"{info.allocation_id}:{info.size}:{info.aligned_size}:{info.tag}".encode()
            )
        # Hash metadata (sorted by key for determinism)
        for key in sorted(self._metadata.keys()):
            entry = self._metadata[key]
            h.update(f"{key}:{entry.allocation_id}:{entry.offset_bytes}:".encode())
            h.update(entry.value)
        return h.hexdigest()

    def on_shutdown(self) -> None:
        """Called on server shutdown. Releases GPU memory and metadata."""
        if self._memory_manager.allocation_count > 0:
            count = self._memory_manager.clear_all()
            logger.info(f"Released {count} GPU allocations during shutdown")
        # FIX: metadata can exist even when no allocations do; the original
        # only cleared it inside the branch above. Always drop it.
        self._metadata.clear()

    # ==================== State Queries ====================

    def handle_get_lock_state(
        self,
        has_rw: bool,
        ro_count: int,
        waiting_writers: int,
        committed: bool,
    ) -> GetLockStateResponse:
        """Get lock/session state as reported by the server's FSM."""
        state = derive_state(has_rw, ro_count, committed)
        return GetLockStateResponse(
            state=state.value,
            has_rw_session=has_rw,
            ro_session_count=ro_count,
            waiting_writers=waiting_writers,
            committed=committed,
            is_ready=committed and not has_rw,
        )

    def handle_get_allocation_state(self) -> GetAllocationStateResponse:
        """Get allocation count and total reserved bytes."""
        return GetAllocationStateResponse(
            allocation_count=self._memory_manager.allocation_count,
            total_bytes=self._memory_manager.total_bytes,
        )

    # ==================== Allocation Operations ====================

    def handle_allocate(self, req: AllocateRequest) -> AllocateResponse:
        """Create physical memory allocation.

        Requires RW connection (enforced by server).
        """
        info = self._memory_manager.allocate(req.size, req.tag)
        return AllocateResponse(
            allocation_id=info.allocation_id,
            size=info.size,
            aligned_size=info.aligned_size,
        )

    def handle_export(self, allocation_id: str) -> tuple[GetAllocationResponse, int]:
        """Export allocation as POSIX FD.

        Returns (response, fd). Caller must close fd after sending.

        Raises:
            ValueError: If the allocation does not exist (consistent with
                handle_get_allocation).
        """
        # FIX: resolve the allocation BEFORE exporting so an unknown ID can
        # never leave an exported FD behind, and map the miss to ValueError
        # exactly like handle_get_allocation does.
        try:
            info = self._memory_manager.get_allocation(allocation_id)
        except AllocationNotFoundError:
            raise ValueError(f"Unknown allocation: {allocation_id}") from None
        fd = self._memory_manager.export_fd(allocation_id)
        response = GetAllocationResponse(
            allocation_id=info.allocation_id,
            size=info.size,
            aligned_size=info.aligned_size,
            tag=info.tag,
        )
        return response, fd

    def handle_get_allocation(self, req: GetAllocationRequest) -> GetAllocationResponse:
        """Get allocation info.

        Raises:
            ValueError: If the allocation does not exist.
        """
        try:
            info = self._memory_manager.get_allocation(req.allocation_id)
            return GetAllocationResponse(
                allocation_id=info.allocation_id,
                size=info.size,
                aligned_size=info.aligned_size,
                tag=info.tag,
            )
        except AllocationNotFoundError:
            raise ValueError(f"Unknown allocation: {req.allocation_id}") from None

    def handle_list_allocations(
        self, req: ListAllocationsRequest
    ) -> ListAllocationsResponse:
        """List all allocations, optionally filtered by the request's tag."""
        allocations = self._memory_manager.list_allocations(req.tag)
        result = [
            {
                "allocation_id": info.allocation_id,
                "size": info.size,
                "aligned_size": info.aligned_size,
                "tag": info.tag,
            }
            for info in allocations
        ]
        return ListAllocationsResponse(allocations=result)

    def handle_free(self, req: FreeRequest) -> FreeResponse:
        """Free single allocation.

        Requires RW connection (enforced by server).
        """
        success = self._memory_manager.free(req.allocation_id)
        return FreeResponse(success=success)

    def handle_clear_all(self) -> ClearAllResponse:
        """Clear all allocations and metadata.

        Requires RW connection (enforced by server).
        """
        count = self._memory_manager.clear_all()
        self._metadata.clear()
        return ClearAllResponse(cleared_count=count)

    # ==================== Metadata Operations ====================

    def handle_metadata_put(self, req: MetadataPutRequest) -> MetadataPutResponse:
        """Insert or overwrite a metadata record."""
        self._metadata[req.key] = MetadataEntry(
            req.allocation_id, req.offset_bytes, req.value
        )
        return MetadataPutResponse(success=True)

    def handle_metadata_get(self, req: MetadataGetRequest) -> MetadataGetResponse:
        """Fetch a metadata record; found=False when the key is absent."""
        entry = self._metadata.get(req.key)
        if entry is None:
            return MetadataGetResponse(found=False)
        return MetadataGetResponse(
            found=True,
            allocation_id=entry.allocation_id,
            offset_bytes=entry.offset_bytes,
            value=entry.value,
        )

    def handle_metadata_delete(
        self, req: MetadataDeleteRequest
    ) -> MetadataDeleteResponse:
        """Delete a metadata record; deleted=False when the key is absent."""
        return MetadataDeleteResponse(
            deleted=self._metadata.pop(req.key, None) is not None
        )

    def handle_metadata_list(self, req: MetadataListRequest) -> MetadataListResponse:
        """List metadata keys (optionally restricted to a prefix), sorted."""
        keys = (
            [k for k in self._metadata if k.startswith(req.prefix)]
            if req.prefix
            else list(self._metadata)
        )
        return MetadataListResponse(keys=sorted(keys))

    def handle_get_memory_layout_hash(self) -> GetStateHashResponse:
        """Return the hash computed at the last commit ("" before any commit)."""
        return GetStateHashResponse(memory_layout_hash=self._memory_layout_hash)
from __future__ import annotations

import asyncio
import logging
import socket
from dataclasses import dataclass, field
from typing import Callable, Optional, Set

from gpu_memory_service.common.types import (
    RO_ALLOWED,
    RW_ALLOWED,
    RW_REQUIRED,
    GrantedLockType,
    ServerState,
    StateEvent,
)

logger = logging.getLogger(__name__)


# =============================================================================
# Connection
# =============================================================================


@dataclass(eq=False)
class Connection:
    """An active client connection holding an RW or RO lock.

    Connection objects ARE the lock state: there is no separate session
    registry. Removing a Connection releases its lock.

    eq=False keeps identity-based equality; __hash__ below makes instances
    usable in sets.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    mode: GrantedLockType
    session_id: str
    recv_buffer: bytearray = field(default_factory=bytearray)

    def __hash__(self) -> int:
        # session_id never changes after construction, so it is a safe key.
        return hash(self.session_id)

    @property
    def raw_socket(self) -> socket.socket:
        """Underlying OS socket (needed for SCM_RIGHTS FD passing)."""
        return self.writer.get_extra_info("socket")

    async def close(self) -> None:
        """Close the transport, swallowing late I/O errors on teardown."""
        self.writer.close()
        try:
            await self.writer.wait_closed()
        except Exception:
            pass


# =============================================================================
# State Machine
# =============================================================================


class InvalidTransition(Exception):
    """An event was applied in a state that does not permit it."""


class OperationNotAllowed(Exception):
    """A request type is forbidden for the current state or lock mode."""


@dataclass(frozen=True)
class Transition:
    """One row of the transition table.

    Attributes:
        from_states: States the event may fire from.
        event: Triggering event.
        to_state: Expected resulting state (None means unconstrained).
        condition: Optional name of a predicate method gating this row.
    """

    from_states: frozenset[ServerState]
    event: StateEvent
    to_state: Optional[ServerState]
    condition: Optional[str] = None


# Single source of truth for legal state changes. Row order matters: rows
# sharing (state, event) are disambiguated by condition, checked top-down.
TRANSITIONS: list[Transition] = [
    # Writer acquires the exclusive lock from an idle server.
    Transition(
        from_states=frozenset({ServerState.EMPTY, ServerState.COMMITTED}),
        event=StateEvent.RW_CONNECT,
        to_state=ServerState.RW,
    ),
    # Writer publishes its work and releases the lock.
    Transition(
        from_states=frozenset({ServerState.RW}),
        event=StateEvent.RW_COMMIT,
        to_state=ServerState.COMMITTED,
    ),
    # Writer disconnects without committing; state is invalidated.
    Transition(
        from_states=frozenset({ServerState.RW}),
        event=StateEvent.RW_ABORT,
        to_state=ServerState.EMPTY,
    ),
    # Reader joins under the shared lock once data is committed.
    Transition(
        from_states=frozenset({ServerState.COMMITTED, ServerState.RO}),
        event=StateEvent.RO_CONNECT,
        to_state=ServerState.RO,
    ),
    # A reader leaves while others remain.
    Transition(
        from_states=frozenset({ServerState.RO}),
        event=StateEvent.RO_DISCONNECT,
        to_state=ServerState.RO,
        condition="has_remaining_readers",
    ),
    # The final reader leaves.
    Transition(
        from_states=frozenset({ServerState.RO}),
        event=StateEvent.RO_DISCONNECT,
        to_state=ServerState.COMMITTED,
        condition="is_last_reader",
    ),
]


@dataclass
class TransitionRecord:
    """Audit-log entry for one executed transition."""

    from_state: ServerState
    event: StateEvent
    to_state: ServerState
    session_id: Optional[str] = None


class GlobalLockFSM:
    """Explicit state machine for GPU Memory Service.

    The truth lives in the connection objects themselves:
      _rw_conn   -- the active writer, or None
      _ro_conns  -- the set of active readers
      _committed -- whether a writer has published
    Every state mutation flows through transition().
    """

    def __init__(self, on_rw_abort: Optional[Callable[[], None]] = None):
        """Initialize the state machine.

        Args:
            on_rw_abort: Cleanup hook invoked when a writer aborts.
        """
        # Connection state - THIS IS THE SOURCE OF TRUTH
        self._rw_conn: Optional[Connection] = None
        self._ro_conns: Set[Connection] = set()
        self._committed: bool = False
        # Callback for RW abort cleanup
        self._on_rw_abort = on_rw_abort
        # Transition history for debugging
        self._transition_log: list[TransitionRecord] = []

    # ==================== State Properties ====================

    @property
    def state(self) -> ServerState:
        """Current state, derived from the live connection objects."""
        if self._rw_conn is not None:
            return ServerState.RW
        if self._ro_conns:
            return ServerState.RO
        return ServerState.COMMITTED if self._committed else ServerState.EMPTY

    @property
    def rw_conn(self) -> Optional[Connection]:
        """The active RW connection, if any."""
        return self._rw_conn

    @property
    def ro_conns(self) -> Set[Connection]:
        """Set of active RO connections."""
        return self._ro_conns

    @property
    def ro_count(self) -> int:
        """Number of active RO connections."""
        return len(self._ro_conns)

    @property
    def committed(self) -> bool:
        """Whether allocations have been committed."""
        return self._committed

    @property
    def transition_log(self) -> list[TransitionRecord]:
        """History of state transitions."""
        return self._transition_log

    # ==================== Transition Conditions ====================

    def _has_remaining_readers(self, conn: Connection) -> bool:
        """True if removing conn would still leave at least one reader."""
        return conn not in self._ro_conns or len(self._ro_conns) > 1

    def _is_last_reader(self, conn: Connection) -> bool:
        """True if conn is the sole remaining reader."""
        return self._ro_conns == {conn}

    def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
        """Evaluate a row's named predicate (None means unconditional)."""
        if condition is None:
            return True
        predicate = {
            "has_remaining_readers": self._has_remaining_readers,
            "is_last_reader": self._is_last_reader,
        }.get(condition)
        if predicate is None:
            raise ValueError(f"Unknown condition: {condition}")
        return predicate(conn)

    # ==================== State Transitions ====================

    def _find_transition(
        self, from_state: ServerState, event: StateEvent, conn: Connection
    ) -> Optional[Transition]:
        """Return the first table row matching (state, event, condition)."""
        return next(
            (
                row
                for row in TRANSITIONS
                if from_state in row.from_states
                and row.event == event
                and self._check_condition(row.condition, conn)
            ),
            None,
        )

    def _apply_event(self, event: StateEvent, conn: Connection) -> None:
        """Mutate internal bookkeeping for an already-validated event."""
        if event is StateEvent.RW_CONNECT:
            self._rw_conn = conn
            self._committed = False  # a new writer invalidates published state
        elif event is StateEvent.RW_COMMIT:
            self._committed = True
            self._rw_conn = None
        elif event is StateEvent.RW_ABORT:
            self._rw_conn = None
            if self._on_rw_abort:
                self._on_rw_abort()
        elif event is StateEvent.RO_CONNECT:
            self._ro_conns.add(conn)
        elif event is StateEvent.RO_DISCONNECT:
            self._ro_conns.discard(conn)

    def transition(self, event: StateEvent, conn: Connection) -> ServerState:
        """Execute a state transition.

        Args:
            event: The triggering event.
            conn: The connection involved in the transition.

        Returns:
            The new state after the transition.

        Raises:
            InvalidTransition: If no table row permits this event from the
                current state, or the resulting state does not match the row.
        """
        prev_state = self.state
        session_id = conn.session_id if conn else None

        row = self._find_transition(prev_state, event, conn)
        if row is None:
            raise InvalidTransition(
                f"No transition for {event.name} from state {prev_state.name} "
                f"(session={session_id})"
            )

        self._apply_event(event, conn)
        next_state = self.state

        # Sanity-check the derived state against the table's expectation.
        if row.to_state is not None and next_state != row.to_state:
            raise InvalidTransition(
                f"Transition mismatch: expected {row.to_state.name}, "
                f"got {next_state.name}"
            )

        self._transition_log.append(
            TransitionRecord(prev_state, event, next_state, session_id)
        )
        logger.info(
            f"State transition: {prev_state.name} --{event.name}--> {next_state.name} "
            f"(session={session_id})"
        )
        return next_state

    # ==================== Operation Permissions ====================

    def check_operation(self, msg_type: type, conn: Connection) -> None:
        """Check if a request type is allowed for the given connection.

        Args:
            msg_type: The request message type (e.g., AllocateRequest).
            conn: The connection attempting the operation.

        Raises:
            OperationNotAllowed: If the operation is not permitted.
        """
        current_state = self.state

        # EMPTY and COMMITTED have no connections, hence nothing is allowed.
        allowed = {
            ServerState.RW: RW_ALLOWED,
            ServerState.RO: RO_ALLOWED,
        }.get(current_state, frozenset())

        if msg_type not in allowed:
            raise OperationNotAllowed(
                f"{msg_type.__name__} not allowed in state {current_state.name}"
            )

        # Mutating operations additionally demand the writer's lock mode.
        if msg_type in RW_REQUIRED and conn.mode != GrantedLockType.RW:
            raise OperationNotAllowed(
                f"{msg_type.__name__} requires RW connection, "
                f"but connection is {conn.mode.value}"
            )

    # ==================== Lock Acquisition Predicates ====================

    def can_acquire_rw(self) -> bool:
        """Check if RW lock can be acquired now.

        RW requires no current writer and no readers. Acquiring from
        COMMITTED is allowed (explicit reload); rw_or_ro callers should
        also inspect `committed` to prefer RO.
        """
        return self._rw_conn is None and not self._ro_conns

    def can_acquire_ro(self, waiting_writers: int) -> bool:
        """Check if RO lock can be acquired now.

        Args:
            waiting_writers: Number of writers waiting for the lock.
        """
        return self._committed and self._rw_conn is None and waiting_writers == 0
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from uuid import uuid4

logger = logging.getLogger(__name__)


@dataclass
class AllocationInfo:
    """Bookkeeping for one physical GPU allocation.

    Attributes:
        allocation_id: Unique identifier for this allocation.
        size: Size requested by the caller, in bytes.
        aligned_size: Size actually reserved (rounded up to VMM granularity).
        handle: CUmemGenericAllocationHandle stored as a plain int.
        tag: Caller-supplied grouping label.
        created_at: Unix timestamp of creation.
    """

    allocation_id: str
    size: int
    aligned_size: int
    handle: int
    tag: str
    created_at: float


class AllocationNotFoundError(Exception):
    """Raised when an allocation_id doesn't exist."""


class GMSServerMemoryManager:
    """Server-side CUDA VMM memory manager (pure business logic).

    Creates physical allocations via the CUDA Virtual Memory Management API
    and exports them as POSIX file descriptors for sharing with other
    processes. Memory is never mapped into a virtual address range here, so
    no CUDA context is created on the server side.

    NOT thread-safe: callers must serialize access externally (the lock
    FSM's RW/RO semantics guarantee single-writer access).
    """

    def __init__(self, device: int = 0):
        self._device = device
        self._allocations: Dict[str, AllocationInfo] = {}
        ensure_cuda_initialized()
        self._granularity = get_allocation_granularity(device)
        logger.info(
            f"GMSServerMemoryManager initialized: device={device}, granularity={self._granularity}"
        )

    @property
    def device(self) -> int:
        """CUDA device ordinal this manager allocates on."""
        return self._device

    @property
    def granularity(self) -> int:
        """VMM allocation granularity in bytes."""
        return self._granularity

    @property
    def allocation_count(self) -> int:
        """Number of live allocations."""
        return len(self._allocations)

    @property
    def total_bytes(self) -> int:
        """Sum of aligned sizes across all live allocations."""
        return sum(rec.aligned_size for rec in self._allocations.values())

    def _get(self, allocation_id: str) -> AllocationInfo:
        """Look up a record or raise AllocationNotFoundError."""
        record = self._allocations.get(allocation_id)
        if record is not None:
            return record
        raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")

    def _release(self, record: AllocationInfo) -> None:
        """Best-effort cuMemRelease; failures are logged, never raised."""
        (result,) = cuda.cuMemRelease(record.handle)
        if result != cuda.CUresult.CUDA_SUCCESS:
            logger.warning(f"cuMemRelease failed for {record.allocation_id}: {result}")

    def allocate(self, size: int, tag: str = "default") -> AllocationInfo:
        """Create a shareable physical allocation (no VA mapping).

        Args:
            size: Requested bytes; rounded up to the VMM granularity.
            tag: Grouping label (e.g., "weights", "kv_cache").

        Returns:
            AllocationInfo with allocation_id, aligned_size, handle.

        Raises:
            RuntimeError: If the CUDA allocation fails.
        """
        aligned_size = align_to_granularity(size, self._granularity)

        # Pinned device memory, exportable as a POSIX file descriptor.
        prop = cuda.CUmemAllocationProp()
        prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
        prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
        prop.location.id = self._device
        prop.requestedHandleTypes = (
            cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
        )

        result, handle = cuda.cuMemCreate(aligned_size, prop, 0)
        check_cuda_result(result, "cuMemCreate")

        record = AllocationInfo(
            allocation_id=str(uuid4()),
            size=size,
            aligned_size=aligned_size,
            handle=int(handle),
            tag=tag,
            created_at=time.time(),
        )
        self._allocations[record.allocation_id] = record
        logger.debug(
            f"Allocated {record.allocation_id}: size={size}, aligned={aligned_size}, tag={tag}"
        )
        return record

    def export_fd(self, allocation_id: str) -> int:
        """Export allocation as a POSIX FD for SCM_RIGHTS transfer.

        The descriptor can be sent to another process over a Unix domain
        socket; the receiver imports it with cuMemImportFromShareableHandle.

        IMPORTANT: the caller owns the returned FD and MUST close it after
        sendmsg() to avoid leaks.

        Args:
            allocation_id: ID of allocation to export.

        Returns:
            File descriptor (caller owns, must close after sending).

        Raises:
            AllocationNotFoundError: If allocation_id doesn't exist.
            RuntimeError: If the CUDA export fails.
        """
        record = self._get(allocation_id)
        result, shareable = cuda.cuMemExportToShareableHandle(
            record.handle,
            cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
            0,
        )
        check_cuda_result(result, "cuMemExportToShareableHandle")
        return int(shareable)

    def free(self, allocation_id: str) -> bool:
        """Release physical memory for a single allocation.

        Args:
            allocation_id: ID of allocation to free.

        Returns:
            True if the allocation existed and was freed, False otherwise.
        """
        record = self._allocations.pop(allocation_id, None)
        if record is None:
            return False
        self._release(record)
        logger.debug(f"Freed allocation: {allocation_id}")
        return True

    def clear_all(self) -> int:
        """Release ALL allocations.

        Used by loaders before loading a new model, or during cleanup when
        a writer aborts without committing.

        Returns:
            Number of allocations cleared.
        """
        count = len(self._allocations)
        for record in self._allocations.values():
            self._release(record)
        self._allocations.clear()
        logger.info(f"Cleared {count} allocations")
        return count

    def get_allocation(self, allocation_id: str) -> AllocationInfo:
        """Public lookup; raises AllocationNotFoundError when missing."""
        return self._get(allocation_id)

    def list_allocations(self, tag: Optional[str] = None) -> List[AllocationInfo]:
        """All allocations, or only those whose tag matches."""
        records = self._allocations.values()
        if tag is None:
            return list(records)
        return [rec for rec in records if rec.tag == tag]
Connection, GlobalLockFSM + +logger = logging.getLogger(__name__) + + +class GMSRPCServer: + """GPU Memory Service RPC Server. + + Async single-threaded server using GlobalLockFSM for explicit state transitions + and operation validation. All state mutations happen through the state machine's + transition() method. + """ + + def __init__(self, socket_path: str, device: int = 0): + self.socket_path = socket_path + self.device = device + + # Request handler (business logic) + self._handler = RequestHandler(device) + + # State machine - handles all state transitions and permission checks + self._sm = GlobalLockFSM(on_rw_abort=self._handler.on_rw_abort) + self._waiting_writers: int = 0 + + # Async waiting for lock acquisition + self._condition = asyncio.Condition() + self._shutdown = False + + # Session ID generation + self._next_session_id: int = 0 + + # Server state + self._server: Optional[asyncio.Server] = None + self._running: bool = False + + logger.info(f"GMSRPCServer initialized: device={device}") + + # ==================== State Properties ==================== + + @property + def state(self) -> ServerState: + """Current server state (delegated to state machine).""" + return self._sm.state + + @property + def granularity(self) -> int: + return self._handler.granularity + + def is_ready(self) -> bool: + """Ready = committed and no RW connection.""" + return self._sm.committed and self._sm.rw_conn is None + + @property + def running(self) -> bool: + """Whether the server is running.""" + return self._running + + def _generate_session_id(self) -> str: + self._next_session_id += 1 + return f"session_{self._next_session_id}" + + # ==================== Connection Lifecycle ==================== + + async def _handle_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + """Handle a connection from accept to close.""" + session_id = self._generate_session_id() + conn: Optional[Connection] = None + + try: + conn = await 
self._do_handshake(reader, writer, session_id) + if conn is None: + return + await self._request_loop(conn) + except ConnectionResetError: + logger.debug(f"Connection reset: {session_id}") + except asyncio.CancelledError: + raise + except Exception: + logger.exception(f"Connection error: {session_id}") + finally: + await self._cleanup_connection(conn) + + async def _do_handshake( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + session_id: str, + ) -> Optional[Connection]: + """Perform handshake and acquire lock via state machine transition.""" + try: + # Server never receives FDs from clients, so no need for raw_sock + msg, _, recv_buffer = await recv_message(reader, bytearray()) + except Exception: + logger.exception("Handshake recv error") + return None + + if not isinstance(msg, HandshakeRequest): + await send_message(writer, ErrorResponse(error="Expected HandshakeRequest")) + writer.close() + return None + + # Acquire lock (blocks until available or timeout) + # Returns the actual granted mode (may differ from requested for rw_or_ro) + granted_mode = await self._acquire_lock(msg.lock_type, msg.timeout_ms) + if granted_mode is None: + await send_message( + writer, HandshakeResponse(success=False, committed=self._sm.committed) + ) + writer.close() + return None + + conn = Connection(reader, writer, granted_mode, session_id, recv_buffer) + + # State transition: connect + event = ( + StateEvent.RW_CONNECT + if granted_mode == GrantedLockType.RW + else StateEvent.RO_CONNECT + ) + self._sm.transition(event, conn) + + await send_message( + writer, + HandshakeResponse( + success=True, + committed=self._sm.committed, + granted_lock_type=granted_mode, + ), + ) + return conn + + async def _acquire_lock( + self, mode: RequestedLockType, timeout_ms: Optional[int] + ) -> Optional[GrantedLockType]: + """Wait until lock can be acquired (uses state machine predicates). + + Returns the granted lock type, or None if failed/timeout. 
+ For rw_or_ro mode, returns RW if available immediately, else waits for RO. + """ + timeout = timeout_ms / 1000 if timeout_ms is not None else None + + if mode == RequestedLockType.RW: + self._waiting_writers += 1 + try: + async with self._condition: + try: + await asyncio.wait_for( + self._condition.wait_for( + lambda: self._shutdown or self._sm.can_acquire_rw() + ), + timeout=timeout, + ) + return None if self._shutdown else GrantedLockType.RW + except asyncio.TimeoutError: + return None + finally: + self._waiting_writers -= 1 + + elif mode == RequestedLockType.RO: + async with self._condition: + try: + await asyncio.wait_for( + self._condition.wait_for( + lambda: self._shutdown + or self._sm.can_acquire_ro(self._waiting_writers) + ), + timeout=timeout, + ) + return None if self._shutdown else GrantedLockType.RO + except asyncio.TimeoutError: + return None + + elif mode == RequestedLockType.RW_OR_RO: + # Auto mode: try RW if available immediately AND no committed weights, + # otherwise wait for RO (to import existing weights) + async with self._condition: + # Check if RW is available AND no committed weights exist + # If weights are already committed, prefer RO to import them + if self._sm.can_acquire_rw() and not self._sm.committed: + return GrantedLockType.RW + + # Either RW not available OR weights already committed - wait for RO + if self._sm.committed: + logger.info( + "RW_OR_RO: Weights already committed, preferring RO to import" + ) + else: + logger.info( + "RW_OR_RO: RW not available (another writer active), " + "falling back to RO" + ) + try: + await asyncio.wait_for( + self._condition.wait_for( + lambda: self._shutdown + or self._sm.can_acquire_ro(self._waiting_writers) + ), + timeout=timeout, + ) + return None if self._shutdown else GrantedLockType.RO + except asyncio.TimeoutError: + return None + return None + + async def _cleanup_connection(self, conn: Optional[Connection]) -> None: + """Clean up after connection closes via state machine 
transition.
+
+        An RW disconnect before commit is treated as an abort; after commit
+        no transition is needed (the commit already performed it). RO
+        disconnects simply deregister. Waiters are then notified because a
+        released lock may now be acquirable.
+        """
+        if conn is None:
+            return
+
+        # State transition: disconnect
+        if conn.mode == GrantedLockType.RW:
+            if self._sm.rw_conn is conn and not self._sm.committed:
+                # RW abort - state machine callback handles cleanup
+                self._sm.transition(StateEvent.RW_ABORT, conn)
+            elif self._sm.rw_conn is conn:
+                # Already committed, no transition needed (commit already did it)
+                pass
+        else:
+            if conn in self._sm.ro_conns:
+                self._sm.transition(StateEvent.RO_DISCONNECT, conn)
+
+        await conn.close()
+        # Wake any blocked _acquire_lock() waiters.
+        async with self._condition:
+            self._condition.notify_all()
+
+    # ==================== Request Handling ====================
+
+    async def _request_loop(self, conn: Connection) -> None:
+        """Process requests until close or commit.
+
+        Returns on connection reset, recv error, or when a dispatched handler
+        signals should_close (e.g. commit). Handler errors are reported back
+        to the client as ErrorResponse without closing the connection.
+        """
+        while self._running:
+            try:
+                # Server never receives FDs from clients, so no need for raw_socket
+                msg, _, conn.recv_buffer = await recv_message(
+                    conn.reader, conn.recv_buffer
+                )
+            except ConnectionResetError:
+                return
+            except asyncio.CancelledError:
+                raise
+            except Exception:
+                logger.exception("Recv error")
+                return
+
+            # None presumably means "no complete frame yet" from recv_message —
+            # TODO confirm against the protocol module.
+            if msg is None:
+                continue
+
+            try:
+                response, fd, should_close = await self._dispatch(conn, msg)
+                if response is not None:
+                    try:
+                        await send_message(conn.writer, response, fd)
+                    finally:
+                        # The fd was duplicated for sending; always release ours.
+                        if fd >= 0:
+                            os.close(fd)
+                if should_close:
+                    return
+            except Exception as e:
+                # NOTE(review): send_message here can itself raise and escape
+                # the loop — acceptable since the outer connection handler
+                # logs it, but worth confirming.
+                logger.exception("Request error")
+                await send_message(conn.writer, ErrorResponse(error=str(e)))
+
+    # Dispatch table: message type -> handler method name
+    # Handlers take (msg) and return response. Special cases handled separately.
+    _HANDLERS: ClassVar[dict[type, str]] = {
+        AllocateRequest: "handle_allocate",
+        GetAllocationRequest: "handle_get_allocation",
+        ListAllocationsRequest: "handle_list_allocations",
+        FreeRequest: "handle_free",
+        MetadataPutRequest: "handle_metadata_put",
+        MetadataGetRequest: "handle_metadata_get",
+        MetadataDeleteRequest: "handle_metadata_delete",
+        MetadataListRequest: "handle_metadata_list",
+    }
+
+    async def _dispatch(self, conn: Connection, msg) -> tuple[object, int, bool]:
+        """Dispatch request to handler. Returns (response, fd, should_close).
+
+        The state machine validates that the operation is permitted for this
+        connection's lock mode before any handler runs.
+        """
+        msg_type = type(msg)
+        self._sm.check_operation(msg_type, conn)
+
+        # Special cases
+        if msg_type is CommitRequest:
+            return await self._handle_commit(conn)
+
+        if msg_type is GetLockStateRequest:
+            return (
+                self._handler.handle_get_lock_state(
+                    self._sm.rw_conn is not None,
+                    self._sm.ro_count,
+                    self._waiting_writers,
+                    self._sm.committed,
+                ),
+                -1,
+                False,
+            )
+
+        if msg_type is GetAllocationStateRequest:
+            return self._handler.handle_get_allocation_state(), -1, False
+
+        if msg_type is ExportRequest:
+            # Export returns a real fd to pass over the socket (closed by caller).
+            response, fd = self._handler.handle_export(msg.allocation_id)
+            return response, fd, False
+
+        if msg_type is ClearAllRequest:
+            return self._handler.handle_clear_all(), -1, False
+
+        if msg_type is GetStateHashRequest:
+            return self._handler.handle_get_memory_layout_hash(), -1, False
+
+        # Standard dispatch: handler takes msg, returns response
+        handler_name = self._HANDLERS.get(msg_type)
+        if handler_name:
+            handler = getattr(self._handler, handler_name)
+            return handler(msg), -1, False
+
+        raise ValueError(f"Unknown request: {msg_type.__name__}")
+
+    async def _handle_commit(self, conn: Connection) -> tuple[object, int, bool]:
+        """Handle commit via state machine transition - atomic with disconnect."""
+        # Compute state hash before transitioning
+        self._handler.on_commit()
+        # State transition: commit
+        self._sm.transition(StateEvent.RW_COMMIT, conn)
+
+        await send_message(conn.writer, CommitResponse(success=True))
+        await conn.close()
+
+        # Wake waiters: the RW lock is now released and weights are committed.
+        async with self._condition:
+            self._condition.notify_all()
+
+        # should_close=True ends the request loop for this connection.
+        return None, -1, True
+
+    # ==================== Server Lifecycle ====================
+
+    async def start(self) -> None:
+        """Bind the Unix socket (removing any stale one) and start accepting."""
+        if os.path.exists(self.socket_path):
+            os.unlink(self.socket_path)
+
+        self._server = await asyncio.start_unix_server(
+            self._handle_connection, path=self.socket_path
+        )
+        self._running = True
+        logger.info(f"Server started: {self.socket_path}")
+
+    async def stop(self) -> None:
+        """Shut down: wake lock waiters, close connections, remove the socket."""
+        self._running = False
+        self._shutdown = True
+        # Wake any _acquire_lock() waiters so they observe _shutdown and bail.
+        async with self._condition:
+            self._condition.notify_all()
+
+        if self._server:
+            self._server.close()
+            await self._server.wait_closed()
+            self._server = None
+
+        # Close connections (bypassing state machine - this is shutdown)
+        if self._sm.rw_conn:
+            await self._sm.rw_conn.close()
+
+        for conn in list(self._sm.ro_conns):
+            await conn.close()
+
+        self._handler.on_shutdown()
+
+        if os.path.exists(self.socket_path):
+            os.unlink(self.socket_path)
+
+        logger.info("Server stopped")
+
+    async def serve_forever(self) -> None:
+        """Start the server and block (1s poll) until stop() clears _running."""
+        await self.start()
+        try:
+            while self._running:
+                await asyncio.sleep(1)
+        finally:
+            await self.stop()
diff --git a/lib/gpu_memory_service/setup.py b/lib/gpu_memory_service/setup.py
new file mode 100644
index 00000000000..d4b27a10218
--- /dev/null
+++ b/lib/gpu_memory_service/setup.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Build script for GPU Memory Service with C++ extensions.
+
+This setup.py builds the C++ extensions as part of pip install.
+The _allocator_ext extension only requires Python headers (no CUDA or PyTorch needed).
+
+Following the torch_memory_saver pattern of using pure setuptools for extension building.
+"""
+
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext
+
+
+class BuildExtension(build_ext):
+    """Custom build extension for C++ modules.
+
+    Overrides the compiler/linker executables so the build honors the CXX
+    environment variable instead of the Python build-time defaults.
+    """
+
+    def build_extensions(self):
+        import os
+
+        # Use CXX environment variable if set, otherwise default to g++
+        cxx = os.environ.get("CXX", "g++")
+        self.compiler.set_executable("compiler_so", cxx)
+        self.compiler.set_executable("compiler_cxx", cxx)
+        # set_executable splits this command string, so "-shared" becomes a
+        # separate linker argument.
+        self.compiler.set_executable("linker_so", f"{cxx} -shared")
+
+        build_ext.build_extensions(self)
+
+
+def _create_ext_modules():
+    """Create extension modules for gpu_memory_service.
+
+    Returns a list with the single _allocator_ext Extension, compiled as
+    C++17 with -O3 -fPIC.
+    """
+    # Common compile arguments
+    extra_compile_args = ["-std=c++17", "-O3", "-fPIC"]
+
+    # _allocator_ext: CUDAPluggableAllocator shim using only Python C API
+    # No CUDA or PyTorch dependency - just provides my_malloc/my_free that call Python callbacks
+    return [
+        Extension(
+            name="gpu_memory_service.client.torch.extensions._allocator_ext",
+            sources=["client/torch/extensions/allocator.cpp"],
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+
+
+setup(
+    name="gpu-memory-service",
+    version="0.8.0",
+    description="GPU Memory Service for Dynamo - CUDA VMM-based GPU memory allocation and sharing",
+    author="NVIDIA Inc.",
+    author_email="sw-dl-dynamo@nvidia.com",
+    license="Apache-2.0",
+    python_requires=">=3.10",
+    install_requires=[
+        "msgpack>=1.0",
+        "uvloop>=0.21.0",
+    ],
+    extras_require={
+        "test": [
+            "pytest>=8.3.4",
+            "pytest-asyncio",
+        ],
+    },
+    # Package directory mapping: the current directory IS the gpu_memory_service package
+    packages=[
+        "gpu_memory_service",
+        "gpu_memory_service.common",
+        "gpu_memory_service.common.protocol",
+        "gpu_memory_service.server",
+        "gpu_memory_service.client",
+        "gpu_memory_service.client.torch",
+        "gpu_memory_service.client.torch.extensions",
+    ],
+    package_dir={
+        "gpu_memory_service": ".",
+        "gpu_memory_service.common": "common",
+        "gpu_memory_service.common.protocol": "common/protocol",
+        "gpu_memory_service.server": "server",
+        "gpu_memory_service.client": "client",
+        "gpu_memory_service.client.torch": "client/torch",
+        "gpu_memory_service.client.torch.extensions": "client/torch/extensions",
+    },
+    # Ship the C++ source so sdists/editable installs can rebuild the extension.
+    package_data={
+        "gpu_memory_service.client.torch.extensions": ["*.cpp"],
+    },
+    ext_modules=_create_ext_modules(),
+    cmdclass={"build_ext": BuildExtension},
+)