Commit d0d138b

chenxi-yang (Chenxi Yang) authored
[Nixl][P/D] Add cuda2cpu support (HD->DH transfer) (vllm-project#24690)

Signed-off-by: Chenxi Yang <[email protected]>
Co-authored-by: Chenxi Yang <[email protected]>
1 parent 4322723 commit d0d138b

File tree: 6 files changed (+96, −15 lines)


tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh

Lines changed: 28 additions & 3 deletions
@@ -1,6 +1,31 @@
 #!/bin/bash
 set -xe

+# Parse command line arguments
+KV_BUFFER_DEVICE="cuda"  # Default to cuda
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    --kv_buffer_device)
+      KV_BUFFER_DEVICE="$2"
+      shift 2
+      ;;
+    *)
+      echo "Unknown option $1"
+      echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
+      exit 1
+      ;;
+  esac
+done
+
+echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
+
+# Build the kv-transfer-config once
+if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
+  KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+else
+  KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
+fi
+
 # Models to run
 MODELS=(
   "Qwen/Qwen3-0.6B"

@@ -79,7 +104,7 @@ run_tests_for_model() {

   # Calculate port number (base port + instance number)
   PORT=$((8100 + i))
-  # Calculate side channel port. Avoid clash with with TP workers.
+  # Calculate side channel port. Avoid clash with with TP workers.
   SIDE_CHANNEL_PORT=$((5559 + i))

   echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"

@@ -93,7 +118,7 @@ run_tests_for_model() {
     --enforce-eager \
     --gpu-memory-utilization 0.2 \
     --tensor-parallel-size $PREFILLER_TP_SIZE \
-    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+    --kv-transfer-config '$KV_CONFIG'"

   if [ -n "$model_args" ]; then
     FULL_CMD="$BASE_CMD $model_args"

@@ -128,7 +153,7 @@ run_tests_for_model() {
     --enforce-eager \
     --gpu-memory-utilization 0.2 \
     --tensor-parallel-size $DECODER_TP_SIZE \
-    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+    --kv-transfer-config '$KV_CONFIG'"

   if [ -n "$model_args" ]; then
     FULL_CMD="$BASE_CMD $model_args"
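
The only behavioral change here is which JSON string reaches --kv-transfer-config. A minimal Python sketch of the same construction, using json.dumps instead of the script's hand-escaped strings; the helper name build_kv_config is hypothetical, for illustration only:

import json

def build_kv_config(kv_buffer_device: str = "cuda") -> str:
    # Mirrors the script: the cuda default omits kv_buffer_device,
    # any other value (e.g. "cpu") is added explicitly.
    config = {"kv_connector": "NixlConnector", "kv_role": "kv_both"}
    if kv_buffer_device != "cuda":
        config["kv_buffer_device"] = kv_buffer_device
    return json.dumps(config)

print(build_kv_config())       # {"kv_connector": "NixlConnector", "kv_role": "kv_both"}
print(build_kv_config("cpu"))  # ... plus "kv_buffer_device": "cpu"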

tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh

Mode changed: 100644 → 100755
Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,33 @@
11
#!/bin/bash
22
set -xe
33

4+
# Parse command line arguments
5+
KV_BUFFER_DEVICE="cuda" # Default to cuda
6+
PREFILL_GPU_ID=4 # Default GPU IDs
7+
DECODE_GPU_ID=5
8+
while [[ $# -gt 0 ]]; do
9+
case $1 in
10+
--kv_buffer_device)
11+
KV_BUFFER_DEVICE="$2"
12+
shift 2
13+
;;
14+
*)
15+
echo "Unknown option $1"
16+
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
17+
exit 1
18+
;;
19+
esac
20+
done
21+
22+
echo "Running edge case tests with kv_buffer_device=$KV_BUFFER_DEVICE (GPUs: $PREFILL_GPU_ID, $DECODE_GPU_ID)"
23+
24+
# Build the kv-transfer-config once
25+
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
26+
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
27+
else
28+
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
29+
fi
30+
431
# Models to run
532
MODELS=(
633
"Qwen/Qwen3-0.6B"
@@ -50,15 +77,15 @@ run_tests_for_model() {
5077

5178
# Get model-specific arguments
5279
local model_args=$(get_model_args "$model_name")
53-
80+
5481
# Start prefill instance
5582
PREFILL_PORT=8001
5683

57-
BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
84+
BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
5885
--port $PREFILL_PORT \
5986
--enforce-eager \
6087
--gpu-memory-utilization 0.2 \
61-
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
88+
--kv-transfer-config '$KV_CONFIG'"
6289

6390
if [ -n "$model_args" ]; then
6491
FULL_CMD="$BASE_CMD $model_args"
@@ -72,11 +99,11 @@ run_tests_for_model() {
7299
DECODE_PORT=8002
73100

74101
# Build the command with or without model-specific args
75-
BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
102+
BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
76103
--port $DECODE_PORT \
77104
--enforce-eager \
78105
--gpu-memory-utilization 0.2 \
79-
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
106+
--kv-transfer-config '$KV_CONFIG'"
80107

81108
if [ -n "$model_args" ]; then
82109
FULL_CMD="$BASE_CMD $model_args"
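
The edge-case script pins the prefill and decode servers to distinct GPUs via CUDA_VISIBLE_DEVICES. A minimal sketch of the same launch pattern from Python, assuming a vllm binary on PATH; the ports, GPU IDs, env vars, and flags mirror the script, while the launch helper itself is illustrative:

import os
import subprocess

KV_CONFIG = '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu"}'

def launch(model: str, gpu_id: int, port: int, side_channel_port: int) -> subprocess.Popen:
    # CUDA_VISIBLE_DEVICES restricts the server to one physical GPU;
    # inside the process that GPU is seen as device 0.
    env = dict(os.environ,
               CUDA_VISIBLE_DEVICES=str(gpu_id),
               VLLM_NIXL_SIDE_CHANNEL_PORT=str(side_channel_port))
    cmd = ["vllm", "serve", model,
           "--port", str(port),
           "--enforce-eager",
           "--gpu-memory-utilization", "0.2",
           "--kv-transfer-config", KV_CONFIG]
    return subprocess.Popen(cmd, env=env)

prefill = launch("Qwen/Qwen3-0.6B", gpu_id=4, port=8001, side_channel_port=5559)
decode = launch("Qwen/Qwen3-0.6B", gpu_id=5, port=8002, side_channel_port=6000)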

vllm/config/kv_transfer.py

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ class KVTransferConfig:
     """The engine id for KV transfers."""

     kv_buffer_device: Optional[str] = "cuda"
-    """The device used by kv connector to buffer the KV cache.
-    Currently only support 'cuda'."""
+    """The device used by kv connector to buffer the KV cache. Choices are
+    'cuda' and 'cpu'."""

     kv_buffer_size: float = 1e9
     """The buffer size for TorchDistributedConnector. Measured in number of

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 7 additions & 1 deletion
@@ -67,7 +67,10 @@
 # Supported platforms and types of kv transfer buffer.
 # {device: tuple of supported kv buffer types}
 _NIXL_SUPPORTED_DEVICE = {
-    "cuda": ("cuda", ),
+    "cuda": (
+        "cuda",
+        "cpu",
+    ),
     "tpu": ("cpu", ),
     "xpu": ("cpu", ),
 }

@@ -701,6 +704,9 @@ def initialize_host_xfer_buffer(

     def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
         """Assign copy (d2h, h2d) operations when host buffer is used."""
+        # Set a no-op if the host buffer is not cpu.
+        if self.kv_buffer_device != "cpu":
+            return
         assert self.use_host_buffer
         self.copy_blocks = copy_operation
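
A short sketch of how a lookup table like _NIXL_SUPPORTED_DEVICE gates platform/buffer pairs; only the table contents come from the diff, the validate helper is hypothetical:

# From the diff: per-platform tuple of supported kv buffer types.
_NIXL_SUPPORTED_DEVICE = {
    "cuda": ("cuda", "cpu"),  # "cpu" is the new cuda2cpu path
    "tpu": ("cpu",),
    "xpu": ("cpu",),
}

def validate_kv_buffer_device(platform: str, kv_buffer_device: str) -> None:
    # Hypothetical helper: reject pairs the connector cannot serve,
    # e.g. ("tpu", "cuda").
    supported = _NIXL_SUPPORTED_DEVICE.get(platform, ())
    if kv_buffer_device not in supported:
        raise ValueError(
            f"kv_buffer_device={kv_buffer_device!r} is not supported on "
            f"{platform!r}; choose one of {supported}")

validate_kv_buffer_device("cuda", "cpu")  # OK after this commit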

vllm/platforms/cuda.py

Lines changed: 24 additions & 0 deletions
@@ -500,6 +500,30 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
                 "You can use float16 instead by explicitly setting the "
                 "`dtype` flag in CLI, for example: --dtype=half.")

+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on GPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from GPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
+
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
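
Both helpers use the same indexing pattern: dimension 0 of the cache is kept whole and blocks are gathered/scattered along dimension 1. A CPU-only toy sketch of that semantics; the [2, num_blocks, block_size] layout is an assumption for illustration, not the exact KV cache layout:

import torch

# Toy caches: dim 0 = K/V plane (assumed), dim 1 = block index.
num_blocks, block_size = 8, 4
src_cache = torch.arange(2 * num_blocks * block_size, dtype=torch.float32)
src_cache = src_cache.reshape(2, num_blocks, block_size)
dst_cache = torch.zeros(2, num_blocks, block_size)

src_block_indices = torch.tensor([0, 3, 5])
dst_block_indices = torch.tensor([1, 2, 4])

# Same pattern as swap_out_blocks_to_host, minus the device hop:
# gather source blocks along dim 1, scatter into destination slots.
dst_cache[:, dst_block_indices] = src_cache[:, src_block_indices]

assert torch.equal(dst_cache[:, 1], src_cache[:, 0])
assert torch.equal(dst_cache[:, 4], src_cache[:, 5])

In the real methods the device transfer (.to(dst_cache.device) or .cpu()) happens on the gathered slice, so only the selected blocks cross the HtoD/DtoH boundary.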

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 4 deletions
@@ -4059,10 +4059,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             self.drafter.validate_same_kv_cache_group(kv_cache_config)

         if has_kv_transfer_group():
-            get_kv_transfer_group().register_kv_caches(kv_caches)
-            if self.device.type == 'xpu':
-                get_kv_transfer_group().set_host_xfer_buffer_ops(
-                    copy_kv_blocks)
+            kv_transfer_group = get_kv_transfer_group()
+            kv_transfer_group.register_kv_caches(kv_caches)
+            kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

         if self.dcp_world_size > 1:
             layer_names = self.attn_groups[0][0].layer_names
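
The xpu-only guard can leave the runner because the connector now guards itself (see the set_host_xfer_buffer_ops change above). A toy sketch of that inversion of responsibility; the stub class and names are illustrative only:

class StubNixlConnector:
    # Stand-in for the connector's new guard: the runner calls
    # set_host_xfer_buffer_ops unconditionally, and the connector
    # ignores it unless the buffer actually lives on the CPU.
    def __init__(self, kv_buffer_device: str):
        self.kv_buffer_device = kv_buffer_device
        self.copy_blocks = None

    def set_host_xfer_buffer_ops(self, copy_operation):
        if self.kv_buffer_device != "cpu":
            return  # no-op, as in the diff
        self.copy_blocks = copy_operation

for device in ("cuda", "cpu"):
    conn = StubNixlConnector(device)
    conn.set_host_xfer_buffer_ops(lambda *args: None)
    print(device, "->", "ops set" if conn.copy_blocks else "no-op")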
