Commit e1404f3

modify kv buffer for hpu
1 parent c6ac725 commit e1404f3

File tree: 5 files changed, +200 -31 lines changed

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
#!/bin/bash
set -xe

# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}

# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}

# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}

OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM server to start.
wait_for_server() {
  local host=$1
  local port=$2
  timeout 1200 bash -c "
    until curl -s ${host}:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Cleanup function
cleanup() {
  echo "Caught Ctrl+C, cleaning up..."
  # Cleanup commands
  pgrep python | xargs kill -9 || true
  # pkill -f python || true
  echo "Cleanup complete. Exiting."
}

launch_baseline() {
  BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${BASELINE_HOST} \
      --port ${BASELINE_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --gpu-memory-utilization 0.5 \
      --enforce-eager"
  echo ${BASELINE_BASE_CMD}
  ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" &
}

launch_pd() {
  PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
    VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${PREFILL_HOST} \
      --port ${PREFILL_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --enforce-eager \
      --gpu-memory-utilization 0.5 \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

  DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${DECODE_HOST} \
      --port ${DECODE_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --enforce-eager \
      --gpu-memory-utilization 0.5 \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

  echo ${PREFILL_BASE_CMD}
  echo ${DECODE_BASE_CMD}
  sleep 2

  # execute on hosts
  ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
  ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
  sleep 1
  wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
  sleep 1
  wait_for_server ${DECODE_HOST} ${DECODE_PORT}
  sleep 1
}

launch_pd_proxy(){
  PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    python3 ${EXP_ROOT}/toy_proxy_server.py \
      --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
      --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
      --host=${PROXY_HOST} --port ${PROXY_PORT}"
  echo ${PROXY_BASE_CMD}
  ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}

run_tests(){
  local service_url=$1
  local mode=$2
  python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE}
}

# run non-disagg. baseline & save outputs
launch_baseline
sleep 2
wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
cleanup
sleep 10

# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
echo "-----P/D success----"

rm ${OUTPUT_FILE}
cleanup

exit 0
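
For reference, the readiness check that wait_for_server performs above (poll the server's /v1/completions route until any HTTP response comes back) can be sketched in Python as follows; the hosts, ports, and timeout are illustrative values taken from the script's defaults, not part of the commit itself.

import time
import urllib.error
import urllib.request


def wait_for_server(host: str, port: int, timeout_s: int = 1200) -> bool:
    """Poll a vLLM server until /v1/completions answers or the timeout expires."""
    url = f"http://{host}:{port}/v1/completions"
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            urllib.request.urlopen(url, timeout=5)
            return True
        except urllib.error.HTTPError:
            # Any HTTP response (even 405 for a bare GET) means the server is up.
            return True
        except (urllib.error.URLError, OSError):
            time.sleep(1)
    return False


if __name__ == "__main__":
    # Defaults mirror PREFILL_HOST / PREFILL_PORT in the script above.
    print(wait_for_server("localhost", 8100))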

tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

Lines changed: 20 additions & 10 deletions
@@ -3,6 +3,7 @@
 
 import argparse
 import itertools
+import logging
 import os
 import uuid
 from contextlib import asynccontextmanager
@@ -11,9 +12,8 @@
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
 
 
 @asynccontextmanager
@@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str,
     }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
+    if "max_completion_tokens" in req_data:
+        req_data["max_completion_tokens"] = 1
     if "stream_options" in req_data:
         del req_data["stream_options"]
     headers = {
@@ -196,8 +198,7 @@ async def stream_service_response(client_info: dict, endpoint: str,
             yield chunk
 
 
-@app.post("/v1/completions")
-async def handle_completions(request: Request):
+async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         request_id = str(uuid.uuid4())
@@ -206,9 +207,8 @@ async def handle_completions(request: Request):
         prefill_client_info = get_next_client(request.app, 'prefill')
 
         # Send request to prefill service
-        response = await send_request_to_service(prefill_client_info,
-                                                  "/completions", req_data,
-                                                  request_id)
+        response = await send_request_to_service(prefill_client_info, api,
+                                                  req_data, request_id)
 
         # Extract the needed fields
         response_json = response.json()
@@ -224,7 +224,7 @@
         # Stream response from decode service
         async def generate_stream():
             async for chunk in stream_service_response(decode_client_info,
-                                                       "/completions",
+                                                       api,
                                                        req_data,
                                                        request_id=request_id):
                 yield chunk
@@ -237,12 +237,22 @@ async def generate_stream():
         import traceback
         exc_info = sys.exc_info()
         print("Error occurred in disagg prefill proxy server"
-              " - completions endpoint")
+              f" - {api} endpoint")
         print(e)
         print("".join(traceback.format_exception(*exc_info)))
         raise
 
 
+@app.post("/v1/completions")
+async def handle_completions(request: Request):
+    return await _handle_completions("/completions", request)
+
+
+@app.post("/v1/chat/completions")
+async def handle_chat_completions(request: Request):
+    return await _handle_completions("/chat/completions", request)
+
+
 @app.get("/healthcheck")
 async def healthcheck():
     """Simple endpoint to check if the server is running."""

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 18 additions & 20 deletions
@@ -59,6 +59,7 @@
 _NIXL_SUPPORTED_XPUS = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
+    "hpu": ("cpu", )
 }
 
 
@@ -467,7 +468,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
                 self.device_type]:
             raise RuntimeError(
-                f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
+                f"kvconf{vllm_config.kv_transfer_config} {self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported.")
         self.device_kv_caches: dict[str, torch.Tensor] = {}
 
@@ -689,9 +690,11 @@ def request_ready(_f: Future[Any], entry=(req_id, meta)):
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
-
         _, first_kv_cache = next(iter(kv_caches.items()))
-        kv_elem_size = first_kv_cache.element_size()
+        if self.device_type == "hpu":
+            kv_elem_size = first_kv_cache[0].dtype.itemsize
+        else:
+            kv_elem_size = first_kv_cache.element_size()
 
         if self.use_host_buffer:
             self.initialize_host_xfer_buffer(kv_caches=kv_caches)
@@ -734,36 +737,31 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_size, kv_latent_dim = block_shape
                 self.slot_size_bytes = kv_elem_size * kv_latent_dim
             else:
-                # [2 (k and v), num_blocks, ...]
-                #if self._use_flashinfer:
-                #    # FlashInfer swaps 2<->num_blocks dimensions.
-                #    self.num_blocks = first_kv_cache.shape[0]
-                #    block_rank = 4  # [2, block_size, kv_heads, head_dim]
-                #else:
-                #    self.num_blocks = first_kv_cache.shape[1]
-                #    block_rank = 3  # [block_size, kv_heads, head_dim]
-                #block_shape = first_kv_cache.shape[-block_rank:]
-                #block_size, n_kv_heads, head_dim = block_shape[-3:]
-
-                # TODO see if below is necessary, else uncomment above
                 # [2 (k and v), num_blocks, ...]
                 if self._use_flashinfer:
                     # FlashInfer swaps 2<->num_blocks dimensions.
                     self.num_blocks = first_kv_cache.shape[0]
                     block_rank = 4  # [2, block_size, kv_heads, head_dim]
                 else:
-                    # habana kv_cache: [2, num_blocks*block_size, kv_heads, head_dim]
-                    self.num_blocks = first_kv_cache.shape[1] // self.block_size
+                    self.num_blocks = first_kv_cache.shape[1]
                     block_rank = 3  # [block_size, kv_heads, head_dim]
                 block_shape = first_kv_cache.shape[-block_rank:]
-                block_shape = list(block_shape)
-                block_shape[0] = block_shape[0] // self.num_blocks
-                block_shape = torch.Size(block_shape)
                 block_size, n_kv_heads, head_dim = block_shape[-3:]
 
                 # head size in bytes.
                 self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
             assert block_size == self.block_size
+        elif self.device_type == "hpu":
+            # habana kv_cache: [2, num_blocks*block_size, kv_heads, head_dim]
+            self.num_blocks = first_kv_cache.shape[1] // self.block_size
+            block_rank = 3  # [block_size, kv_heads, head_dim]
+            block_shape = first_kv_cache.shape[-block_rank:]
+            block_shape = list(block_shape)
+            block_shape[0] = block_shape[0] // self.num_blocks
+            block_shape = torch.Size(block_shape)
+            block_size, n_kv_heads, head_dim = block_shape[-3:]
+            # head size in bytes.
+            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
         else:
             raise RuntimeError(
                 f"{self.device_type} ({self.backend_name}) is not supported.")

vllm/v1/core/kv_cache_manager.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def __init__(
         self,
         kv_cache_config: KVCacheConfig,
         max_model_len: int,
-        enable_caching: bool = True,
+        enable_caching: bool = False,
         caching_hash_algo: str = "builtin",
         use_eagle: bool = False,
         log_stats: bool = False,

vllm/v1/worker/hpu_model_runner.py

Lines changed: 2 additions & 0 deletions
@@ -2430,6 +2430,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self._PAD_SLOT_ID = num_blocks * self.block_size
 
         if has_kv_transfer_group():
+            #import remote_pdb; remote_pdb.set_trace()
+            kv_caches = { layer: torch.stack((tup[0], tup[1])) for layer,tup in kv_caches.items()}
             get_kv_transfer_group().register_kv_caches(kv_caches)
 
         htorch.hpu.synchronize()
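
The added comprehension folds the HPU runner's per-layer (key, value) pairs into single tensors with a leading dimension of 2, which is the layout the NIXL connector's non-MLA path indexes. A toy illustration of that transformation, with invented layer names and shapes:

import torch

# Pretend per-layer HPU caches: each layer maps to a (key_cache, value_cache) tuple.
num_slots, kv_heads, head_dim = 2048, 8, 128
kv_caches = {
    f"model.layers.{i}.self_attn": (
        torch.empty(num_slots, kv_heads, head_dim, dtype=torch.bfloat16),
        torch.empty(num_slots, kv_heads, head_dim, dtype=torch.bfloat16),
    )
    for i in range(2)
}

# Same transformation as the added line: stack k and v into one tensor per layer.
kv_caches = {layer: torch.stack((tup[0], tup[1])) for layer, tup in kv_caches.items()}

for layer, cache in kv_caches.items():
    # The leading 2 is what register_kv_caches reads as key vs value.
    print(layer, tuple(cache.shape))  # e.g. ('model.layers.0.self_attn', (2, 2048, 8, 128))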
