Skip to content

Commit 65f9ef1

Browse files
authored
add result_dir paramerter to InferArgument (#1561)
1 parent c222140 commit 65f9ef1

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

swift/llm/infer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,15 @@ def llm_infer(args: InferArguments) -> Dict[str, List[Dict[str, Any]]]:
325325
result: List[Dict[str, Any]] = []
326326
jsonl_path = None
327327
if args.save_result:
328-
result_dir = args.ckpt_dir
329-
if result_dir is None:
330-
result_dir = llm_engine.model_dir if args.infer_backend in {'vllm', 'lmdeploy'} else model.model_dir
328+
if args.result_dir:
329+
result_dir = args.result_dir
330+
else:
331+
result_dir = args.ckpt_dir
332+
if result_dir is None:
333+
result_dir = llm_engine.model_dir if args.infer_backend in {'vllm', 'lmdeploy'} else model.model_dir
334+
if result_dir is not None:
335+
result_dir = os.path.join(result_dir, 'infer_result')
331336
if result_dir is not None:
332-
result_dir = os.path.join(result_dir, 'infer_result')
333337
os.makedirs(result_dir, exist_ok=True)
334338
time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
335339
jsonl_path = os.path.join(result_dir, f'{time}.jsonl')

swift/llm/utils/argument.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,7 @@ class InferArguments(ArgumentsBase):
11631163
default='AUTO', metadata={'help': f"template_type choices: {list(TEMPLATE_MAPPING.keys()) + ['AUTO']}"})
11641164
infer_backend: Literal['AUTO', 'vllm', 'pt', 'lmdeploy'] = 'AUTO'
11651165
ckpt_dir: Optional[str] = field(default=None, metadata={'help': '/path/to/your/vx-xxx/checkpoint-xxx'})
1166+
result_dir: Optional[str] = field(default=None, metadata={'help': '/path/to/your/infer_result'})
11661167
load_args_from_ckpt_dir: bool = True
11671168
load_dataset_config: bool = False
11681169
eval_human: Optional[bool] = None

swift/llm/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def _prepare_inputs(model: PreTrainedModel,
613613
if max_length and token_len + generation_config.max_new_tokens > max_length:
614614
generation_config.max_new_tokens = max_length - token_len
615615
if generation_config.max_new_tokens <= 0:
616-
raise AssertionError('Current sentence length exceeds' f'the model max_length: {max_length}')
616+
raise AssertionError(f'Current sentence length exceeds the model max_length: {max_length}')
617617
if template.suffix[-1] not in stop_words:
618618
stop_words.append(template.suffix[-1])
619619
inputs = to_device(inputs, device)

0 commit comments

Comments
 (0)