
Commit 3605cb8

fix dataset_sample & deploy stop_words (#1385)
1 parent f8faae8 · commit 3605cb8

File tree: 3 files changed, +17 -12 lines changed


swift/llm/deploy.py

Lines changed: 12 additions & 5 deletions
@@ -184,14 +184,15 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
     request_id = request_info['request_id']
 
     kwargs = {'max_new_tokens': request.max_tokens}
-    for key in ['n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
+    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
         kwargs[key] = getattr(request, key)
     for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(llm_engine.generation_config, key)
         else:
             kwargs[key] = new_value
+    kwargs['stop'] = (llm_engine.generation_config.stop or []) + (getattr(request, 'stop') or [])
 
     generation_config = VllmGenerationConfig(**kwargs)
     if generation_config.use_beam_search and request.stream:

@@ -343,7 +344,7 @@ def __repr__(self) -> str:
 
 @torch.inference_mode()
 async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
-    global model, template
+    global model, template, _args
     result = await _prepare_request(request)
     if isinstance(result, JSONResponse):
         return result

@@ -359,8 +360,13 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(model.generation_config, key)
+            if key == 'temperature':
+                do_sample = getattr(model.generation_config, 'do_sample')
+                if not do_sample:
+                    kwargs[key] = 0
         else:
             kwargs[key] = new_value
+
     if kwargs['temperature'] == 0:
         kwargs['do_sample'] = False
         kwargs['temperature'] = 1

@@ -374,7 +380,8 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
     set_generation_config(model, generation_config)  # inplace
     model.generation_config = _old_generation_config
     request_info['generation_config'] = generation_config
-    request_info.update({'seed': request.seed, 'stop': request.stop, 'stream': request.stream})
+    stop = (_args.stop_words or []) + (getattr(request, 'stop') or [])
+    request_info.update({'seed': request.seed, 'stop': stop, 'stream': request.stream})
     logger.info(request_info)
 
     created_time = int(time.time())

@@ -397,7 +404,7 @@ async def _generate_full():
            model,
            template,
            **example,
-            stop_words=request.stop,
+            stop_words=stop,
            generation_config=generation_config,
            generation_info=generation_info,
            **adapter_kwargs)

@@ -441,7 +448,7 @@ def _generate_stream():
            model,
            template,
            **example,
-            stop_words=request.stop,
+            stop_words=stop,
            generation_config=generation_config,
            generation_info=generation_info,
            **adapter_kwargs)
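The net effect of the deploy.py changes: server-side stop words (from the vLLM generation config, or from `_args.stop_words` on the PyTorch path) are now merged with whatever `stop` list the request carries, instead of being silently replaced by it, and a `None` on either side no longer breaks the concatenation. A minimal runnable sketch of that merging pattern (the helper name `merge_stop_words` is hypothetical, not part of the commit):

from typing import List, Optional

def merge_stop_words(server_stop: Optional[List[str]],
                     request_stop: Optional[List[str]]) -> List[str]:
    # `x or []` normalizes None to an empty list, so the concatenation
    # never raises and server-side stop words always survive.
    return (server_stop or []) + (request_stop or [])

assert merge_stop_words(None, None) == []
assert merge_stop_words(['Observation:'], None) == ['Observation:']
assert merge_stop_words(['Observation:'], ['###']) == ['Observation:', '###']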

swift/llm/utils/argument.py

Lines changed: 1 addition & 1 deletion
@@ -1117,7 +1117,7 @@ class InferArguments(ArgumentsBase):
     top_p: float = 0.7
     repetition_penalty: float = 1.
     num_beams: int = 1
-    stop_words: List[str] = None
+    stop_words: List[str] = field(default_factory=list)
 
     # rope-scaling
     rope_scaling: Literal['linear', 'dynamic'] = None
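The one-line argument.py change removes the `None` default on a list-typed dataclass field. With `None`, every consumer has to guard with `stop_words or []`; `field(default_factory=list)` gives each instance its own fresh list, and a literal `= []` default would be rejected by the dataclass machinery anyway. A small sketch of the difference, using a hypothetical `Args` class rather than the real `InferArguments`:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Args:
    # stop_words: List[str] = []   # ValueError: mutable default not allowed
    stop_words: List[str] = field(default_factory=list)

a, b = Args(), Args()
a.stop_words.append('Observation:')
assert b.stop_words == []  # each instance gets its own list, nothing shared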

swift/llm/utils/dataset.py

Lines changed: 4 additions & 6 deletions
@@ -341,12 +341,10 @@ def sample_dataset(dataset: HfDataset, dataset_sample: int, random_state: Optional[RandomState] = None) -> HfDataset:
         return dataset
     if random_state is None:
         random_state = RandomState()
-    # Sample the part that exceeds the length of the dataset.
-    idx = random_state.permutation(len(dataset))[:dataset_sample]
-    dataset_sample -= len(idx)
-    if dataset_sample > 0:
-        idx2 = random_state.choice(len(dataset), dataset_sample)
-        idx = np.concatenate([idx, idx2], axis=0)
+
+    idx_repeat = np.tile(range(len(dataset)), dataset_sample // len(dataset))
+    idx_random = random_state.permutation(len(dataset))[:dataset_sample % len(dataset)]
+    idx = np.concatenate([idx_repeat, idx_random])
     dataset = dataset.select(idx)
     return dataset
 
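The rewritten sampling logic changes how oversampling behaves: instead of drawing the overflow with `random_state.choice` (which samples with replacement, so some examples can be picked many times and others skipped), it repeats the whole dataset `dataset_sample // len(dataset)` times and fills the remainder with a prefix of a random permutation, so per-example counts differ by at most one. A standalone sketch over plain indices (the helper `sample_indices` is illustrative, not from the commit):

import numpy as np
from numpy.random import RandomState

def sample_indices(n: int, dataset_sample: int,
                   random_state: RandomState) -> np.ndarray:
    idx_repeat = np.tile(range(n), dataset_sample // n)            # whole passes
    idx_random = random_state.permutation(n)[:dataset_sample % n]  # remainder
    return np.concatenate([idx_repeat, idx_random])

idx = sample_indices(4, 10, RandomState(42))
assert len(idx) == 10
counts = np.bincount(idx, minlength=4)
assert counts.max() - counts.min() <= 1  # every example covered evenly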
