Skip to content

Commit 89d43c8

Browse files
authored
fix: do not set model max length when loading model (#21)
* fix: do not set the model max length when loading the model
* fix: log message to use the proper train args value
1 parent c23cc9b commit 89d43c8

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tuning/sft_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def train(
7777
tokenizer = transformers.AutoTokenizer.from_pretrained(
7878
model_args.model_name_or_path,
7979
cache_dir=train_args.cache_dir,
80-
model_max_length=train_args.model_max_length,
8180
padding_side="right",
8281
use_fast = True
8382
)
@@ -96,7 +95,7 @@ def train(
9695
model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
9796
logger.info(f"Model max length {model_max_length}")
9897
if train_args.model_max_length > tokenizer.model_max_length:
99-
logger.warning(f"model_max_length {model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}")
98+
logger.warning(f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}")
10099

101100
# TODO: we need to change this, perhaps follow what open instruct does?
102101
special_tokens_dict = dict()

0 commit comments

Comments (0)