
Commit 4040502

fix load_from_ckpt_dir bug (#161)
1 parent 85c2592 commit 4040502


8 files changed: +9 -9 lines


examples/pytorch/llm/README.md

Lines changed: 1 addition & 1 deletion

@@ -614,4 +614,4 @@ The template initialization function retrieves the complete chat template based
 - `--ignore_args_error`: Default value is `False`. For specific parameter details, please refer to the `sft.sh Command Line Arguments`.
 - `--stream`: Whether to use streaming output. Default value is `True`.
 - `--merge_lora_and_save`: Whether to merge the lora weights into the base model and save the complete weights. Default value is `False`. The weights will be saved in a directory named `checkpoint-xxx-merged` at the same level as `ckpt_dir`, e.g., `'/path/to/your/vx_xxx/checkpoint-xxx-merged'`.
-- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default value is `False`. The generate_config file saved during training will be overwritten.
+- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default value is `False`. The generation_config file saved during training will be overwritten.

examples/pytorch/llm/README_CN.md

Lines changed: 1 addition & 1 deletion

@@ -617,4 +617,4 @@ if __name__ == '__main__':
 - `--ignore_args_error`: Defaults to `False`. See `sft.sh Command Line Arguments` for details.
 - `--stream`: Whether to use streaming output. Defaults to `True`.
 - `--merge_lora_and_save`: Whether to merge the lora weights into the base model and save the complete weights. Defaults to `False`. The weights are saved in a sibling directory of `ckpt_dir`, e.g. under `'/path/to/your/vx_xxx/checkpoint-xxx-merged'`.
-- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Defaults to `False`. The generate_config file saved during training will be overwritten.
+- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Defaults to `False`. The generation_config file saved during training will be overwritten.

swift/llm/infer.py

Lines changed: 0 additions & 1 deletion

@@ -110,7 +110,6 @@ def prepare_model_template(
         args.system, args.max_length,
         args.truncation_strategy)
     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,

swift/llm/rome.py

Lines changed: 0 additions & 1 deletion

@@ -59,7 +59,6 @@ def rome_infer(args: RomeArguments) -> None:
         args.system, args.max_length,
         args.truncation_strategy)
     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,

swift/llm/sft.py

Lines changed: 0 additions & 1 deletion

@@ -164,7 +164,6 @@ def llm_sft(args: SftArguments) -> str:
 
     # Setting training_args
     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,
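
The same `max_length=None` argument is dropped from all three `GenerationConfig(...)` calls above (infer.py, rome.py, sft.py). Passing it explicitly replaces the constructor default (`max_length=20`) with `None`, which the warning-suppression checks elsewhere in this commit then have to special-case; omitting it keeps the default. A minimal sketch of the difference, using the standard `transformers` API with illustrative values:

    from transformers import GenerationConfig

    # Explicit max_length=None overrides the constructor default of 20 ...
    cfg_explicit = GenerationConfig(max_length=None, max_new_tokens=128)
    print(cfg_explicit.max_length)  # None

    # ... while omitting the argument keeps the library default.
    cfg_default = GenerationConfig(max_new_tokens=128)
    print(cfg_default.max_length)  # 20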

swift/llm/utils/argument.py

Lines changed: 2 additions & 2 deletions

@@ -537,8 +537,8 @@ def load_from_ckpt_dir(args: InferArguments) -> None:
     with open(sft_args_path, 'r') as f:
         sft_args = json.load(f)
     imported_keys = [
-        'model_id_or_path', 'model_revision', 'model_cache_dir', 'sft_type',
-        'template_type', 'dtype', 'system', 'quantization_bit',
+        'model_type', 'model_id_or_path', 'model_revision', 'model_cache_dir',
+        'sft_type', 'template_type', 'dtype', 'system', 'quantization_bit',
         'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type',
         'bnb_4bit_use_double_quant'
     ]
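
This is the bug the commit title refers to: `model_type` was missing from `imported_keys`, so inference started from a checkpoint directory did not restore the model type chosen at training time. A minimal sketch of the import pattern (the surrounding file handling and the helper name are simplified assumptions; only the key list matches the real code):

    import json


    def load_from_ckpt_dir_sketch(args, sft_args_path: str) -> None:
        """Copy training-time settings from sft_args.json onto the
        inference arguments (simplified sketch of load_from_ckpt_dir)."""
        with open(sft_args_path, 'r') as f:
            sft_args = json.load(f)
        imported_keys = [
            'model_type', 'model_id_or_path', 'model_revision', 'model_cache_dir',
            'sft_type', 'template_type', 'dtype', 'system', 'quantization_bit',
            'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type',
            'bnb_4bit_use_double_quant'
        ]
        for key in imported_keys:
            if key in sft_args:  # tolerate older checkpoints missing a key
                setattr(args, key, sft_args[key])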

swift/llm/utils/utils.py

Lines changed: 4 additions & 1 deletion

@@ -350,7 +350,8 @@ def inference_stream(
     model.__class__.sample_stream = NewGenerationMixin.sample_stream
     stream_config = StreamGenerationConfig(
         **generation_config.to_dict(), do_stream=True)
-    stream_config.max_length = int(1e9)  # fix max_length, max_new_tokens bug
+    if stream_config.max_new_tokens is not None:
+        stream_config.max_length = 20  # fix max_length, max_new_tokens bug
     stream_config.do_sample = True  # avoid is_greedy_gen_mode = True
     gen = model.generate_stream(
         input_ids=input_ids,
@@ -395,6 +396,8 @@ def inference(model: PreTrainedModel,
     streamer = None
     if stream:
         streamer = TextStreamer(tokenizer, skip_prompt=True)
+    if generation_config.max_new_tokens is not None:
+        generation_config.max_length = 20  # fix max_length, max_new_tokens bug
     generate_ids = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
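
`transformers` warns when both `max_length` and `max_new_tokens` are set, and `max_new_tokens` takes precedence during generation. The new guard, applied in both the streaming and non-streaming paths, resets `max_length` to the library default of 20 whenever `max_new_tokens` is present, replacing the earlier `int(1e9)` workaround. A standalone sketch of the guard (the helper name is illustrative, not part of swift):

    from transformers import GenerationConfig


    def prefer_max_new_tokens(config: GenerationConfig) -> GenerationConfig:
        """Reset max_length to its default (20) when max_new_tokens is set,
        so max_new_tokens is the effective length control in generate()."""
        if config.max_new_tokens is not None:
            config.max_length = 20  # fix max_length, max_new_tokens bug
        return config


    config = prefer_max_new_tokens(GenerationConfig(max_new_tokens=512))
    print(config.max_length, config.max_new_tokens)  # 20 512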

swift/trainers/trainers.py

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ def prediction_step(
         gen_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
         # fix generate warning
         if ('max_length' in gen_kwargs and 'max_new_tokens' in gen_kwargs
-                and gen_kwargs['max_length'] is None):
+                and gen_kwargs['max_new_tokens'] is not None):
             gen_kwargs.pop('max_length')
         gen_time = time.time()
         generate_inputs = inputs.copy()
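
The old condition popped `max_length` only when it was `None`; with the `GenerationConfig` changes above, `max_length` now carries the default value rather than `None`, so the guard keys off `max_new_tokens` instead. A standalone sketch with illustrative dict contents:

    gen_kwargs = {'max_length': 20, 'max_new_tokens': 512, 'do_sample': True}

    # fix generate warning: when max_new_tokens is set it takes precedence,
    # so drop max_length before calling model.generate(**gen_kwargs)
    if ('max_length' in gen_kwargs and 'max_new_tokens' in gen_kwargs
            and gen_kwargs['max_new_tokens'] is not None):
        gen_kwargs.pop('max_length')

    print(gen_kwargs)  # {'max_new_tokens': 512, 'do_sample': True}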
