
Commit 6db8f23

[bugfix] fix accuracy problem when chunked prefill (#438)
* fix accuracy problem when chunked prefill
1 parent 9e6a315 commit 6db8f23

File tree

1 file changed (+17, -15 lines)

ucm/integration/vllm/ucm_connector.py

Lines changed: 17 additions & 15 deletions
@@ -39,6 +39,9 @@ class RequestMeta:
     hbm_hit_block_num: int = 0
     # local_computed_block + external_computed_block
     total_hit_block_num: int = 0
+    num_token_ids: int = 0
+    vllm_block_ids: list[int] = field(default_factory=list)
+    token_processed: int = 0


 @dataclass
@@ -207,7 +210,6 @@ def get_num_new_matched_tokens(
         request: "Request",
         num_computed_tokens: int,
     ) -> tuple[int, bool]:
-
         assert num_computed_tokens % self.block_size == 0
         hbm_hit_block_num = num_computed_tokens // self.block_size

@@ -242,13 +244,16 @@ def get_num_new_matched_tokens(
         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
         # once vLLM scheduler provides a better solution in the future.
-        if total_hit_block_num * self.block_size == request.num_tokens:
+        num_total_hit_tokens = total_hit_block_num * self.block_size
+        if num_total_hit_tokens == request.num_tokens:
             external_hit_tokens -= 1

         self.requests_meta[request.request_id] = RequestMeta(
             ucm_block_ids=ucm_block_ids,
             hbm_hit_block_num=hbm_hit_block_num,
             total_hit_block_num=total_hit_block_num,
+            num_token_ids=len(request.all_token_ids),
+            token_processed=num_total_hit_tokens,
         )

         return external_hit_tokens, False
@@ -277,22 +282,16 @@ def _generate_dispatch_meta(
         | scheduled_block_num |
         """

-        new_blocks_num = new_tokens // self.block_size
         hbm_hit_block_num = req_meta.hbm_hit_block_num
         total_hit_block_num = req_meta.total_hit_block_num
-        scheduled_block_num = total_hit_block_num + new_blocks_num
         ucm_block_ids = req_meta.ucm_block_ids
+        req_meta.vllm_block_ids.extend(vllm_block_ids)

-        dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num]
-        if need_load:
-            dump_vllm_block_ids = vllm_block_ids[
-                total_hit_block_num:scheduled_block_num
-            ]
-        else:
-            dump_vllm_block_ids = vllm_block_ids
-
-        # after this round, req_meta will be updated
-        req_meta.total_hit_block_num = scheduled_block_num
+        start_idx = req_meta.token_processed // self.block_size
+        end_idx = (req_meta.token_processed + new_tokens) // self.block_size
+        dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
+        dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
+        req_meta.token_processed += new_tokens

         load_ucm_block_ids, load_vllm_block_ids = [], []
         if need_load:
@@ -332,10 +331,13 @@ def build_connector_meta(
                 continue
             req_meta = self.requests_meta.get(request_id)
             if req_meta:
+                new_block_ids = []
+                if scheduled_cached_reqs.new_block_ids[i] != None:
+                    new_block_ids = scheduled_cached_reqs.new_block_ids[i][0]
                 requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
                     req_meta,
                     scheduler_output.num_scheduled_tokens[request_id],
-                    scheduled_cached_reqs.new_block_ids[i][0],
+                    new_block_ids,
                     scheduled_cached_reqs.resumed_from_preemption[i],
                 )
             else:
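
The sketch below illustrates the token_processed accounting the patch introduces: on each chunked-prefill step, only the blocks that become completely filled by the tokens processed so far are selected for dumping. This is a minimal, stand-alone sketch rather than the connector's actual code; RequestMeta is stripped down to the relevant fields, step() is a hypothetical helper, and integer ucm_block_ids and the 1000-token example are assumptions made for illustration.

from dataclasses import dataclass, field


@dataclass
class RequestMeta:
    # Stripped-down stand-in for the connector's RequestMeta (illustration only).
    ucm_block_ids: list[int] = field(default_factory=list)
    vllm_block_ids: list[int] = field(default_factory=list)
    token_processed: int = 0


def step(meta: RequestMeta, new_tokens: int, new_vllm_block_ids: list[int], block_size: int):
    """Hypothetical helper: return the (ucm, vllm) block ids completed by this step."""
    # Accumulate the vLLM block ids allocated so far, as the patch does.
    meta.vllm_block_ids.extend(new_vllm_block_ids)
    # Only blocks fully covered by the tokens processed up to and including this
    # chunk fall inside [start_idx, end_idx); a trailing partial block waits.
    start_idx = meta.token_processed // block_size
    end_idx = (meta.token_processed + new_tokens) // block_size
    dump_ucm = meta.ucm_block_ids[start_idx:end_idx]
    dump_vllm = meta.vllm_block_ids[start_idx:end_idx]
    meta.token_processed += new_tokens
    return dump_ucm, dump_vllm


# A 1000-token prompt with block_size=128, scheduled as two chunks of 512 and 488 tokens.
meta = RequestMeta(ucm_block_ids=list(range(8)))
print(step(meta, 512, [100, 101, 102, 103], 128))  # ([0, 1, 2, 3], [100, 101, 102, 103])
print(step(meta, 488, [104, 105, 106, 107], 128))  # ([4, 5, 6], [104, 105, 106])

Because end_idx uses floor division, a block is only dumped once it is completely filled, regardless of how the prompt is split across scheduling steps.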
