Skip to content

Commit 2e7f976

Browse files
GuanLuo and paulpak58
authored and committed
fix: NIXL connector transfers partial block to pass full multi-modal context (vllm-project#21074)
Signed-off-by: GuanLuo <[email protected]> Signed-off-by: Paul Pak <[email protected]>
1 parent 4fb8402 commit 2e7f976

File tree

4 files changed

+130
-41
lines changed

4 files changed

+130
-41
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ def test_prompt_less_than_block_size():
173173
"""
174174
Test that we can handle case where prompt is < block.
175175
176-
In this case, the P worker will send empty remote_block_ids.
177-
The D worker should not schedule an async read in this case,
178-
since there is nothing to pull.
176+
In this case, the P worker will still send remote_block_ids of the
177+
partial block. The D worker should schedule an async read
178+
in this case.
179179
"""
180180
vllm_config = create_vllm_config()
181181
scheduler = create_scheduler(vllm_config)
@@ -184,22 +184,20 @@ def test_prompt_less_than_block_size():
184184
BLOCK_SIZE = vllm_config.cache_config.block_size
185185
NUM_TOKENS = int(BLOCK_SIZE * 0.5)
186186

187-
# Request will have 0 remote blocks.
187+
# Request will have 1 partial remote block.
188188
request = create_request(request_id=1,
189189
num_tokens=NUM_TOKENS,
190190
do_remote_prefill=True,
191-
num_remote_blocks=0)
191+
num_remote_blocks=1)
192192
scheduler.add_request(request)
193193
scheduler_output = scheduler.schedule()
194194

195-
# This request should not have to read async.
195+
# This request will read async.
196196
kv_connector_metadata = scheduler_output.kv_connector_metadata
197197
assert kv_connector_metadata is not None
198198
assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
199-
assert len(kv_connector_metadata.reqs_to_recv) == 0
200-
201-
# This request should be scheduled regularly.
202-
assert len(scheduler_output.scheduled_new_reqs) == 1
199+
assert len(kv_connector_metadata.reqs_to_recv) == 1
200+
assert len(scheduler_output.scheduled_new_reqs) == 0
203201

204202

205203
class FakeNixlConnectorWorker(NixlConnectorWorker):

tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,18 @@ def test_short_prompt_lifecycle():
121121
model_runner_output = create_model_runner_output(reqs=[request])
122122

123123
# (1c): update_from_output()
124-
# Since tokens < block_size, there will be no kv xfer.
125-
# So this should be cleaned up immediately.
126-
_ = scheduler.update_from_output(scheduler_output, model_runner_output)
124+
# Even though tokens < block_size, there will be kv xfer for partial block.
125+
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
126+
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
127+
128+
assert (len(kv_transfer_params["remote_block_ids"]) == 1)
127129

128130
# Confirm we do not have any memory leaks after req lifecycle.
129-
# We need one more call to schedule() to clear data for persistent batch.
130-
_ = scheduler.schedule()
131+
# We need to mark sending finish to clear data for persistent batch.
132+
scheduler_output = scheduler.schedule()
133+
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
134+
model_runner_output.finished_sending = [request.request_id]
135+
scheduler.update_from_output(scheduler_output, model_runner_output)
131136
assert_scheduler_empty(scheduler)
132137

133138

@@ -169,16 +174,16 @@ def test_prefix_cache_lifecycle():
169174
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
170175
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
171176

172-
# Ensure we send all block ids, even if there is a cache hit.
177+
# Ensure we send all block ids, including the partial blocks,
178+
# even if there is a cache hit.
173179
assert (len(
174-
kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS)
180+
kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
181+
1))
175182

176183
# STEP (2): Ensure it is freed.
177184
scheduler_output = scheduler.schedule()
178-
scheduler.schedule()
179185
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
180186
model_runner_output.kv_connector_output = KVConnectorOutput(
181187
finished_sending=[request_remote.request_id])
182188
scheduler.update_from_output(scheduler_output, model_runner_output)
183-
_ = scheduler.schedule()
184189
assert_scheduler_empty(scheduler)

tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def test_cannot_schedule_after_recv():
362362
BLOCK_SIZE = vllm_config.cache_config.block_size
363363
# Prompt will use 2 blocks + 1 block after we schedule.
364364
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
365-
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
365+
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
366366

367367
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
368368
request_remote = create_request(request_id=2,
@@ -393,30 +393,124 @@ def test_cannot_schedule_after_recv():
393393
assert len(scheduler.running) == 1
394394
assert len(scheduler.waiting) == 1
395395

396-
# Step 4: try to schedule, not enough blocks.
396+
# Step 4: try to schedule, remote request is put to running list
397+
# because the transfer is completed.
398+
scheduler_output = scheduler.schedule()
399+
model_runner_output = create_model_runner_output(
400+
reqs=[request_normal, request_remote])
401+
scheduler.update_from_output(scheduler_output, model_runner_output)
402+
assert len(scheduler.running) == 2
403+
assert len(scheduler.waiting) == 0
404+
405+
# Step 5: Remote request will be put back to waiting list
406+
# because it needs new block to hold generated token.
397407
scheduler_output = scheduler.schedule()
398408
model_runner_output = create_model_runner_output(reqs=[request_normal])
399409
scheduler.update_from_output(scheduler_output, model_runner_output)
400410
assert len(scheduler.running) == 1
401411
assert len(scheduler.waiting) == 1
402412

403-
# Step 5: finish the request, free it.
413+
# Step 6: finish the request, free it.
404414
scheduler_output = scheduler.schedule()
405415
model_runner_output = create_model_runner_output(reqs=[request_normal],
406416
use_eos=True)
407417
scheduler.update_from_output(scheduler_output, model_runner_output)
408418
assert len(scheduler.running) == 0
409419
assert len(scheduler.waiting) == 1
410420

411-
# Step 6: now we can schedule (with 2 blocks computed).
421+
# Step 7: now we can schedule (with 2 blocks computed),
422+
# request is retrieved from preempted list.
412423
scheduler_output = scheduler.schedule()
413424
model_runner_output = create_model_runner_output(reqs=[request_remote])
414-
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
425+
assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] ==
415426
NUM_PROMPT_BLOCKS * BLOCK_SIZE)
416427
scheduler.update_from_output(scheduler_output, model_runner_output)
417428
assert len(scheduler.running) == 1
418429
assert len(scheduler.waiting) == 0
419430

431+
# Step 8: free everything.
432+
scheduler_output = scheduler.schedule()
433+
model_runner_output = create_model_runner_output(reqs=[request_remote],
434+
use_eos=True)
435+
scheduler.update_from_output(scheduler_output, model_runner_output)
436+
_ = scheduler.schedule()
437+
assert_scheduler_empty(scheduler)
438+
439+
440+
def test_cannot_recv():
441+
"""
442+
Test that we can handle no schedule KV block transfer due to not
443+
enough remaining KV blocks.
444+
"""
445+
446+
# NOTE: the KVCacheManager will use 1 null block.
447+
# So there are 5 total working blocks.
448+
TOTAL_NUM_BLOCKS = 6
449+
vllm_config = create_vllm_config()
450+
scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
451+
452+
# Prime the KVCache.
453+
NUM_PROMPT_BLOCKS = 2
454+
BLOCK_SIZE = vllm_config.cache_config.block_size
455+
# Prompt will use 2 blocks + 1 block after we schedule.
456+
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
457+
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
458+
459+
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
460+
request_remote = create_request(request_id=2,
461+
num_tokens=NUM_TOKENS_REMOTE,
462+
do_remote_prefill=True)
463+
464+
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
465+
scheduler.add_request(request_normal)
466+
scheduler_output = scheduler.schedule()
467+
model_runner_output = create_model_runner_output(reqs=[request_normal])
468+
scheduler.update_from_output(scheduler_output, model_runner_output)
469+
assert len(scheduler.running) == 1
470+
assert len(scheduler.waiting) == 0
471+
472+
# Step 2: 3 blocks are in use,
473+
# need 3 new for remote blocks but only 2 are available.
474+
scheduler.add_request(request_remote)
475+
scheduler_output = scheduler.schedule()
476+
model_runner_output = create_model_runner_output(reqs=[request_normal])
477+
scheduler.update_from_output(scheduler_output, model_runner_output)
478+
assert len(scheduler.running) == 1
479+
assert len(scheduler.waiting) == 1
480+
# Should not have KV transfer in progress.
481+
assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS)
482+
483+
# Step 3: finish the request, free it.
484+
scheduler_output = scheduler.schedule()
485+
model_runner_output = create_model_runner_output(reqs=[request_normal],
486+
use_eos=True)
487+
scheduler.update_from_output(scheduler_output, model_runner_output)
488+
assert len(scheduler.running) == 0
489+
assert len(scheduler.waiting) == 1
490+
491+
# Step 4: now we can initiate KV transfer (with 2 blocks computed).
492+
scheduler_output = scheduler.schedule()
493+
model_runner_output = create_model_runner_output(reqs=[])
494+
scheduler.update_from_output(scheduler_output, model_runner_output)
495+
assert len(scheduler.running) == 0
496+
assert len(scheduler.waiting) == 1
497+
assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
498+
499+
# Step 5: finish recving (5 blocks in use)
500+
scheduler_output = scheduler.schedule()
501+
model_runner_output = create_model_runner_output(
502+
reqs=[], finished_recving=[request_remote.request_id])
503+
scheduler.update_from_output(scheduler_output, model_runner_output)
504+
assert len(scheduler.running) == 0
505+
assert len(scheduler.waiting) == 1
506+
507+
# Step 6: schedule remote request
508+
scheduler_output = scheduler.schedule()
509+
model_runner_output = create_model_runner_output(reqs=[request_remote])
510+
scheduler.update_from_output(scheduler_output, model_runner_output)
511+
assert len(scheduler.running) == 1
512+
assert len(scheduler.waiting) == 0
513+
420514
# Step 7: free everything.
421515
scheduler_output = scheduler.schedule()
422516
model_runner_output = create_model_runner_output(reqs=[request_remote],

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from vllm.forward_context import ForwardContext
3030
from vllm.logger import init_logger
3131
from vllm.platforms import _Backend, current_platform
32-
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
32+
from vllm.utils import make_zmq_path, make_zmq_socket
3333
from vllm.v1.core.sched.output import SchedulerOutput
3434
from vllm.v1.request import RequestStatus
3535

@@ -275,10 +275,7 @@ def get_num_new_matched_tokens(
275275

276276
if params is not None and params.get("do_remote_prefill"):
277277
# Remote prefill: get all prompt blocks from remote.
278-
assert num_computed_tokens % self.block_size == 0
279-
rounded_num_prompt_tokens = round_down(
280-
len(request.prompt_token_ids), self.block_size)
281-
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
278+
count = len(request.prompt_token_ids) - num_computed_tokens
282279
if count > 0:
283280
return count, True
284281

@@ -301,18 +298,16 @@ def update_state_after_alloc(self, request: "Request",
301298
# NOTE: when accelerator is not directly supported by Nixl,
302299
# prefilled blocks need to be saved to host memory before transfer.
303300

304-
# figure out full computed blocks to save
301+
# save all blocks
305302
block_ids = blocks.get_block_ids()[0]
306-
all_full = request.num_tokens % self.block_size == 0
307-
full_block_ids = (block_ids if all_full else block_ids[:-1])
308303
# TODO: skip the blocks that are already in the host xfer buffer.
309304
# Currently, the host xfer buffer block is 1-to-1 mapped to device
310305
# kv blocks, so host blocks won't be flushed as long as its device
311306
# block is not overwritten; and it will be safe to skip saving them
312307
# to host xfer buffer.
313-
if full_block_ids:
308+
if block_ids:
314309
self._reqs_need_save[request.request_id] = \
315-
(request, full_block_ids)
310+
(request, block_ids)
316311
elif params.get("do_remote_prefill"):
317312
if params.get("remote_block_ids"):
318313
if all(p in params for p in ("remote_engine_id", "remote_host",
@@ -401,12 +396,9 @@ def request_finished(
401396
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
402397
return False, None
403398

404-
# Get computed blocks.
405-
all_full = request.num_computed_tokens % self.block_size == 0
406-
computed_block_ids = block_ids if all_full else block_ids[:-1]
407-
408-
# If prompt < block_size, no xfer so free blocks immediately.
409-
delay_free_blocks = len(computed_block_ids) > 0
399+
# TODO: check whether block_ids actually ever be 0. If not we could
400+
# remove the conditional below
401+
delay_free_blocks = len(block_ids) > 0
410402

411403
if delay_free_blocks:
412404
# Prefill request on remote. It will be read from D upon completion
@@ -416,7 +408,7 @@ def request_finished(
416408
return delay_free_blocks, dict(
417409
do_remote_prefill=True,
418410
do_remote_decode=False,
419-
remote_block_ids=computed_block_ids,
411+
remote_block_ids=block_ids,
420412
remote_engine_id=self.engine_id,
421413
remote_host=self.side_channel_host,
422414
remote_port=self.side_channel_port,

0 commit comments

Comments (0)