Commit 3e11240

address review: clarity

Signed-off-by: nicklucche <[email protected]>
1 parent: 0e8f82a

File tree

2 files changed (+49, -50 lines)

tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

Lines changed: 19 additions & 23 deletions
@@ -151,8 +151,18 @@ async def send_request_to_service(client_info: dict, endpoint: str,
     Send a request to a service using a client from the pool.
     """
     req_data = req_data.copy()
-    req_data['do_remote_decode'] = True
+    req_data['kv_transfer_params'] = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None
+    }
     req_data["stream"] = False
+    req_data["max_tokens"] = 1
+    if "stream_options" in req_data:
+        del req_data["stream_options"]
     headers = {
         "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         "X-Request-Id": request_id
@@ -167,22 +177,14 @@ async def send_request_to_service(client_info: dict, endpoint: str,
 
 
 async def stream_service_response(client_info: dict, endpoint: str,
-                                  req_data: dict, remote_block_ids: list[int],
-                                  remote_engine_id: str, remote_host: str,
-                                  remote_port: int, request_id: str):
+                                  req_data: dict, request_id: str):
     """
     Asynchronously stream response from a service using a client from the pool.
     """
     headers = {
         "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         "X-Request-Id": request_id
     }
-    req_data = req_data.copy()
-    req_data['do_remote_prefill'] = True
-    req_data["remote_block_ids"] = remote_block_ids
-    req_data['remote_engine_id'] = remote_engine_id
-    req_data["remote_host"] = remote_host
-    req_data["remote_port"] = remote_port
 
     async with client_info['client'].stream("POST",
                                             endpoint,
@@ -209,10 +211,9 @@ async def handle_completions(request: Request):
 
     # Extract the needed fields
     response_json = response.json()
-    remote_block_ids = response_json.get('remote_block_ids', [])
-    remote_engine_id = response_json.get('remote_engine_id', '')
-    remote_host = response_json.get('remote_host', '')
-    remote_port = response_json.get('remote_port', 0)
+    kv_transfer_params = response_json.get('kv_transfer_params', {})
+    if kv_transfer_params:
+        req_data["kv_transfer_params"] = kv_transfer_params
 
     # Get the next decode client in round-robin fashion
     decode_client_info = get_next_client(request.app, 'decode')
@@ -221,15 +222,10 @@ async def handle_completions(request: Request):
 
     # Stream response from decode service
    async def generate_stream():
-        async for chunk in stream_service_response(
-                decode_client_info,
-                "/completions",
-                req_data,
-                remote_block_ids=remote_block_ids,
-                remote_engine_id=remote_engine_id,
-                remote_host=remote_host,
-                remote_port=remote_port,
-                request_id=request_id):
+        async for chunk in stream_service_response(decode_client_info,
+                                                   "/completions",
+                                                   req_data,
+                                                   request_id=request_id):
             yield chunk
 
     return StreamingResponse(generate_stream(),
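Net effect on the proxy: it no longer unpacks individual remote_* fields; it simply forwards whatever kv_transfer_params the prefill response returns into the decode request. A condensed sketch of the round trip, assuming httpx-style async clients named prefill_client and decode_client (hypothetical names; the real server pools clients and streams the decode response):

async def proxy_round_trip(req_data: dict, request_id: str) -> dict:
    headers = {"X-Request-Id": request_id}
    # 1. Prefill: non-streaming, max_tokens=1, seeded kv_transfer_params.
    prefill_resp = await prefill_client.post("/completions",
                                             json=req_data, headers=headers)
    # 2. Forward the returned transfer metadata opaquely to the decoder.
    kv_params = prefill_resp.json().get("kv_transfer_params", {})
    if kv_params:
        req_data["kv_transfer_params"] = kv_params
    # 3. Decode: generates the actual completion, pulling KV via NIXL.
    decode_resp = await decode_client.post("/completions",
                                           json=req_data, headers=headers)
    return decode_resp.json()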

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 30 additions & 27 deletions
@@ -318,6 +318,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.info("Initializing NIXL wrapper")
         logger.info("Initializing NIXL worker %s", engine_id)
 
+        # Config.
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@@ -378,7 +382,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
-        self.consumer_notification_counts_by_req = defaultdict(int)
+        self.consumer_notification_counts_by_req: dict[str,
+                                                       int] = defaultdict(int)
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
@@ -424,41 +429,39 @@ def _nixl_handshake(self, host: str, port: int):
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(sock, rank: int) -> NixlAgentMetadata:
+        def handshake(path: str, rank: int) -> NixlAgentMetadata:
             # Send query for the request.
-            sock.send(GET_META_MSG)
-            metadata_bytes = sock.recv()
-            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-            metadata = decoder.decode(metadata_bytes)
-            got_metadata_time = time.perf_counter()
-
-            # Register Remote agent.
-            self.add_remote_agent(metadata, rank)
-            setup_agent_time = time.perf_counter()
-
-            logger.debug("NIXL handshake: get metadata took: %s",
-                         got_metadata_time - start_time)
-            logger.debug("NIXL handshake: add agent took: %s",
-                         setup_agent_time - got_metadata_time)
-            return metadata
+            with zmq_ctx(zmq.REQ, path) as sock:
+                sock.send(GET_META_MSG)
+                metadata_bytes = sock.recv()
+                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+                metadata = decoder.decode(metadata_bytes)
+                got_metadata_time = time.perf_counter()
+
+                # Register Remote agent.
+                self.add_remote_agent(metadata, rank)
+                setup_agent_time = time.perf_counter()
+
+                logger.debug("NIXL handshake: get metadata took: %s",
+                             got_metadata_time - start_time)
+                logger.debug("NIXL handshake: add agent took: %s",
+                             setup_agent_time - got_metadata_time)
+                return metadata
 
         # Handshake with remote agent-rank0 first to get the tp_size of remote
         path = f"tcp://{host}:{port}"
         logger.debug("Querying master rank metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
-            metadata = handshake(sock, 0)
+        metadata = handshake(path, 0)
 
         # Handshake only with the other TP remote the current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        d_workers_per_p_worker = self._tp_size[
-            self.engine_id] // metadata.tp_size
-        p_remote_rank = self.rank // d_workers_per_p_worker
+        tp_rate = self._tp_size[self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // tp_rate
         if p_remote_rank > 0:
             path = f"tcp://{host}:{port + p_remote_rank}"
             logger.debug("Querying metadata on path: %s at remote rank %s",
                          path, p_remote_rank)
-            with zmq_ctx(zmq.REQ, path) as sock:
-                metadata = handshake(sock, p_remote_rank)
+            _ = handshake(path, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -473,17 +476,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, kv_latent_dim = block_shape
+            block_size, kv_latent_dim = block_shape
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, n_kv_heads, head_dim = block_shape
+            block_size, n_kv_heads, head_dim = block_shape
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
-
+        assert block_size == self.block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
         # block size in bytes