Skip to content

Commit 7b93555

Browse files
authored
[python] Update rolling batch params to output delta (#2636)
1 parent d613c76 commit 7b93555

File tree

4 files changed

+21
-133
lines changed

4 files changed

+21
-133
lines changed

engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from lmi_dist.arg_utils import VllmEngineArgs
2020
from lmi_dist.init_engine import engine_from_args
2121
from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
22-
from vllm import SamplingParams
22+
from vllm.sampling_params import RequestOutputKind
2323
from vllm.utils import AtomicCounter
2424

2525
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
@@ -153,6 +153,7 @@ def translate_lmi_dist_params(self, parameters: dict):
153153
154154
:return: The same parameters dict, but with lmi-dist style parameter names.
155155
"""
156+
parameters["output_kind"] = RequestOutputKind.DELTA
156157
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
157158
do_sample = parameters.pop("do_sample", None)
158159
if do_sample is not None and do_sample is False:

engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def update_request_cache_with_output(request_cache: OrderedDict,
7474
request_output.prompt_tokens_details.append(prompt_token)
7575

7676
# sets the details of all sequences
77-
update_multiple_sequences(cache, request_output, vllm_request_output)
77+
update_multiple_sequences(request_output, vllm_request_output)
7878

7979
# remove finished requests from cache
8080
if vllm_request_output.finished:
@@ -89,49 +89,28 @@ def update_request_cache_with_output(request_cache: OrderedDict,
8989
return request_cache
9090

9191

92-
def update_multiple_sequences(cache, request_output, vllm_request_output):
92+
def update_multiple_sequences(request_output, vllm_request_output):
9393
for completion_output in vllm_request_output.outputs:
94-
9594
sequence_index = completion_output.index
96-
if f"sequence_index_{sequence_index}" not in cache:
97-
cache[f"sequence_index_{sequence_index}"] = {
98-
"curr_length": 0,
99-
"num_generated_tokens": 0
100-
}
10195

10296
if sequence_index not in request_output.sequences:
10397
request_output.sequences[sequence_index] = Sequence()
10498

105-
# set token of the sequence
106-
# previous length of token ids generated
107-
prev_len = cache[f"sequence_index_{sequence_index}"][
108-
'num_generated_tokens']
109-
# curr length of the token ids generated so far
110-
cur_len = len(completion_output.token_ids)
111-
cache[f"sequence_index_{sequence_index}"][
112-
"num_generated_tokens"] = cur_len
113-
11499
# get the newly generated token_ids
115-
new_token_ids = completion_output.token_ids[
116-
prev_len:
117-
cur_len] if prev_len < cur_len else completion_output.token_ids
100+
new_token_ids = completion_output.token_ids
118101

119102
# get the newly generated token texts for speculative decoding
120103
output_token_texts = []
121104
if hasattr(completion_output, "output_token_texts"):
122-
output_token_texts = completion_output.output_token_texts[
123-
prev_len:
124-
cur_len] if prev_len < cur_len else completion_output.output_token_texts
105+
output_token_texts = completion_output.output_token_texts
125106

126107
top_tokens = []
127108
token_texts = []
128109
# calculate log probs and token_texts
129110
if completion_output.logprobs:
130-
new_logprobs_list = completion_output.logprobs[
131-
prev_len:
132-
cur_len] if prev_len < cur_len else completion_output.logprobs
133111
new_logprobs = []
134-
for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
112+
for token_id, logprobs in zip(new_token_ids,
113+
completion_output.logprobs):
135114
new_logprobs.append(logprobs[token_id].logprob)
136115
decoded_token = logprobs[token_id].decoded_token if logprobs[
137116
token_id].decoded_token else ""
@@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
141120
Token(id=token_id_key,
142121
text=logprob.decoded_token,
143122
log_prob=logprob.logprob))
144-
145123
elif new_token_ids:
146124
# TODO: Test and remove this. logprobs is always set to 1, so this case should never happen.
147125
new_logprobs = [None] * len(new_token_ids)
148-
curr_length = cache[f"sequence_index_{sequence_index}"][
149-
"curr_length"]
150-
token_texts.append(completion_output.text[curr_length:])
126+
token_texts.append(completion_output.text)
151127

152128
if not output_token_texts:
153129
if len(token_texts) != len(new_token_ids):
@@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
186162
request_output.sequences[sequence_index].set_next_top_tokens(
187163
top_tokens)
188164

189-
cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
190-
completion_output.text)
191-
192165

193166
def get_speculative_decoding_metrics_record(
194167
completion_output: CompletionOutput,

engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from collections import OrderedDict, defaultdict
1515

1616
from vllm import LLMEngine, SamplingParams
17+
from vllm.sampling_params import RequestOutputKind
1718
from vllm.utils import random_uuid, AtomicCounter
1819

1920
from djl_python.request import Request
@@ -85,6 +86,7 @@ def translate_vllm_params(self, parameters: dict) -> dict:
8586
8687
:return: The same parameters dict, but with VLLM style parameter names.
8788
"""
89+
parameters["output_kind"] = RequestOutputKind.DELTA
8890
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
8991
do_sample = parameters.pop("do_sample", None)
9092
if do_sample is not None and do_sample is False:

engines/python/setup/djl_python/tests/test_rb_vllm_utils.py

Lines changed: 10 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import sys
22
import unittest
3-
import uuid
43
from dataclasses import dataclass
54
from typing import List, Optional, Dict, Union
65
from collections import OrderedDict
@@ -12,8 +11,8 @@
1211
import djl_python
1312
from djl_python.output_formatter import _json_output_formatter
1413
from djl_python.request import Request
15-
from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput
16-
'''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1'''
14+
from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token
15+
'''These Mock classes are in compliance with vllm RequestOutput version 0.6.3.post1'''
1716

1817

1918
@dataclass
@@ -148,23 +147,10 @@ def __init__(
148147
],
149148
outputs=[
150149
MockCompletionOutput(index=1,
151-
text=' member of',
152-
token_ids=[4292, 302],
150+
text=' of',
151+
token_ids=[302],
153152
cumulative_logprob=-4.3041129764169455,
154153
logprobs=[{
155-
4292:
156-
MockLogprob(logprob=-4.2740092277526855,
157-
rank=4,
158-
decoded_token=' member'),
159-
2032:
160-
MockLogprob(logprob=-3.0240092277526855,
161-
rank=1,
162-
decoded_token=' big'),
163-
888:
164-
MockLogprob(logprob=-4.4099884033203125,
165-
rank=3,
166-
decoded_token=' new'),
167-
}, {
168154
302:
169155
MockLogprob(logprob=-0.03010374866425991,
170156
rank=1,
@@ -181,27 +167,10 @@ def __init__(
181167
finish_reason=None,
182168
stop_reason=None),
183169
MockCompletionOutput(index=0,
184-
text=' consolidated',
185-
token_ids=[22968, 601],
170+
text='ated',
171+
token_ids=[601],
186172
cumulative_logprob=-13.402491569519043,
187173
logprobs=[{
188-
22968:
189-
MockLogprob(logprob=-12.117759704589844,
190-
rank=5308,
191-
decoded_token=' consolid'),
192-
2032:
193-
MockLogprob(logprob=-3.0240092277526855,
194-
rank=1,
195-
decoded_token=' big'),
196-
17372:
197-
MockLogprob(logprob=-13.409988403320312,
198-
rank=10489,
199-
decoded_token=' crown'),
200-
888:
201-
MockLogprob(logprob=-4.4099884033203125,
202-
rank=3,
203-
decoded_token=' new'),
204-
}, {
205174
601:
206175
MockLogprob(logprob=-1.2847318649291992,
207176
rank=2,
@@ -235,37 +204,10 @@ def __init__(
235204
],
236205
outputs=[
237206
MockCompletionOutput(index=1,
238-
text=' member of the',
239-
token_ids=[4292, 302,
240-
272],
207+
text=' the',
208+
token_ids=[272],
241209
cumulative_logprob=-4.815703457221389,
242210
logprobs=[{
243-
4292:
244-
MockLogprob(logprob=-4.2740092277526855,
245-
rank=4,
246-
decoded_token=' member'),
247-
2032:
248-
MockLogprob(logprob=-3.0240092277526855,
249-
rank=1,
250-
decoded_token=' big'),
251-
888:
252-
MockLogprob(logprob=-4.4099884033203125,
253-
rank=3,
254-
decoded_token=' new'),
255-
}, {
256-
302:
257-
MockLogprob(logprob=-0.03010374866425991,
258-
rank=1,
259-
decoded_token=' of'),
260-
235290:
261-
MockLogprob(logprob=-2.2026185989379883,
262-
rank=1,
263-
decoded_token='-'),
264-
578:
265-
MockLogprob(logprob=-2.2026185989379883,
266-
rank=2,
267-
decoded_token=' and')
268-
}, {
269211
272:
270212
MockLogprob(logprob=-0.5115904808044434,
271213
rank=1,
@@ -282,40 +224,10 @@ def __init__(
282224
finish_reason='length',
283225
stop_reason=None),
284226
MockCompletionOutput(index=0,
285-
text=' consolidated or',
286-
token_ids=[22968, 601, 442],
227+
text=' or',
228+
token_ids=[442],
287229
cumulative_logprob=-20.4010648727417,
288230
logprobs=[{
289-
22968:
290-
MockLogprob(logprob=-12.117759704589844,
291-
rank=5308,
292-
decoded_token=' consolid'),
293-
2032:
294-
MockLogprob(logprob=-3.0240092277526855,
295-
rank=1,
296-
decoded_token=' big'),
297-
17372:
298-
MockLogprob(logprob=-13.409988403320312,
299-
rank=10489,
300-
decoded_token=' crown'),
301-
888:
302-
MockLogprob(logprob=-4.4099884033203125,
303-
rank=3,
304-
decoded_token=' new'),
305-
}, {
306-
601:
307-
MockLogprob(logprob=-1.2847318649291992,
308-
rank=2,
309-
decoded_token='ated'),
310-
1028:
311-
MockLogprob(logprob=-0.909731924533844,
312-
rank=1,
313-
decoded_token='ator'),
314-
1162:
315-
MockLogprob(logprob=-0.8929234743118286,
316-
rank=2,
317-
decoded_token=' year')
318-
}, {
319231
442:
320232
MockLogprob(logprob=-6.998573303222656,
321233
rank=188,

0 commit comments

Comments
 (0)