Skip to content

Commit 10916b2

Browse files
authored
[PD] Fix for issue where requests occasionally go missing in async transfer (SW-234952) (#1978)
1 parent 715e3c1 commit 10916b2

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

benchmarks/benchmark_serving.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,12 @@ def sample_random_requests(
374374
size=prefix_len).tolist()
375375

376376
input_lens = np.random.randint(
377-
int(input_len * range_ratio),
377+
max(1, int(input_len * range_ratio)), # At least 1 input token
378378
input_len + 1,
379379
size=num_prompts,
380380
)
381381
output_lens = np.random.randint(
382-
int(output_len * range_ratio),
382+
max(1, int(output_len * range_ratio)), # At least 1 output token
383383
output_len + 1,
384384
size=num_prompts,
385385
)

vllm/core/scheduler.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -581,13 +581,18 @@ def put_to_shared_dict(prefix, kv_cache, hidden_states):
581581

582582
self.scheduler_profiler.start('internal', 'fetching_kv')
583583
hash_prefix = hash_list(seq_group.prompt_token_ids)
584-
prefix, kv_cache, hidden_states = get_kv_and_hidden_states(
585-
hash_prefix)
586-
if kv_cache is not None:
584+
if len(seq_group.prompt_token_ids) == 1:
585+
# This is a padding seq. Won't be able to fetch KV. skip it.
586+
logger.info("seq len is 1, skip fetching kv...")
587587
fetching_success = True
588-
put_to_shared_dict(prefix, kv_cache, hidden_states)
589588
else:
590-
fetching_success = False
589+
prefix, kv_cache, hidden_states = get_kv_and_hidden_states(
590+
hash_prefix)
591+
if kv_cache is not None:
592+
fetching_success = True
593+
put_to_shared_dict(prefix, kv_cache, hidden_states)
594+
else:
595+
fetching_success = False
591596
self.fetching_done.put((seq_group, fetching_success))
592597
self.fetching_queue.task_done()
593598
self.scheduler_profiler.end()

vllm/worker/hpu_model_runner.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2946,14 +2946,21 @@ def async_recv_kv_caches(model, model_input, attn_metadata,
29462946
model_input.attn_metadata.seq_lens_tensor
29472947
seq_lens = seq_lens_tensor.tolist() #2D list
29482948
hidden_states_list = []
2949+
HIDDEN_SHAPE = (1, 7168)
29492950
start_block_idx = 0
29502951
k_v_head_size = 576
29512952
bypass_model_exec = True
29522953
htorch.core.mark_step()
29532954
for idx, slen in enumerate(seq_lens):
29542955
if slen == 1:
2955-
hidden_states_list.append(
2956-
hidden_states_list[0])
2956+
if hidden_states_list:
2957+
hidden_states_list.append(
2958+
hidden_states_list[0])
2959+
else:
2960+
logger.warning("The first seq len is 1")
2961+
dummy_hidden = torch.zeros(HIDDEN_SHAPE, device="hpu")
2962+
hidden_states_list.append(
2963+
dummy_hidden)
29572964
# skip the seq with only one token
29582965
continue
29592966
num_blocks = (slen + self.block_size -

0 commit comments

Comments (0)