Skip to content

Commit a098e70

Browse files
Luka-Dwillmjkmehant
authored
fix: Move deprecated positional arguments from SFTTrainer to SFTConfig (#399)
* fix: set legacy behavior to false, enable new behavior Signed-off-by: Will Johnson <[email protected]> * fix: Resolve push_to_hub_token warning Signed-off-by: Will Johnson <[email protected]> * fix: Remove max_seq_length and dataset_text_field from SFTTrainer Signed-off-by: Will Johnson <[email protected]> * fmt Signed-off-by: Will Johnson <[email protected]> * fix: Resolve tokenizer.padding_side warning Signed-off-by: Will Johnson <[email protected]> * nit: restructure warning fixes Signed-off-by: Will Johnson <[email protected]> * fix: Add packing directly to SFTConfig Signed-off-by: Will Johnson <[email protected]> * fmt Signed-off-by: Will Johnson <[email protected]> * Removed dataset_kwargs from SFTTrainer Removed the argument dataset_kwargs from the the invocation of SFTTRainer() because it will be deprecated in V1.0.0. Instead, dataset_kwargs have been added as a key to the training_args variable. Following the example provided by HF found here: https://huggingface.co/docs/trl/en/sft_trainer#training-the-vision-language-model Signed-off-by: Luka Dojcinovic <[email protected]> * fix: Added max_seq_length back to SFTConfig() Signed-off-by: Luka Dojcinovic <[email protected]> * Removed legacy and padding_side args Removed these args as they were based on changes from @willmj that haven't been approved yet Signed-off-by: Luka Dojcinovic <[email protected]> * Moved all args to additional_args Following @kmehant suggestion. Signed-off-by: Luka Dojcinovic <[email protected]> * Removed packing and max_seq_length Removed packing and max_seq_length variables from additional_args Signed-off-by: Luka Dojcinovic <[email protected]> * Removed check is_pretokenized_dataset Co-authored-by: Mehant Kammakomati <[email protected]> Signed-off-by: Luka-D <[email protected]> * Removed max_seq_length from additional_args Signed-off-by: Luka Dojcinovic <[email protected]> * Removed error.log Signed-off-by: Luka Dojcinovic <[email protected]> * fix: move packing to SFTConfig as well Co-authored-by: Luka-D <[email protected]> Signed-off-by: Mehant Kammakomati <[email protected]> --------- Signed-off-by: Will Johnson <[email protected]> Signed-off-by: Luka Dojcinovic <[email protected]> Signed-off-by: Luka-D <[email protected]> Signed-off-by: Mehant Kammakomati <[email protected]> Co-authored-by: Will Johnson <[email protected]> Co-authored-by: Mehant Kammakomati <[email protected]> Co-authored-by: Mehant Kammakomati <[email protected]>
1 parent 689ee41 commit a098e70

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tuning/sft_trainer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,27 +318,29 @@ def train(
318318
# this validation, we just drop the things that aren't part of the SFT Config and build one
319319
# from our object directly. In the future, we should consider renaming this class and / or
320320
# not adding things that are not directly used by the trainer instance to it.
321+
321322
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
322323
transformer_kwargs = {
323324
k: v
324325
for k, v in train_args.to_dict().items()
325326
if k in transformer_train_arg_fields
326327
}
327-
training_args = SFTConfig(**transformer_kwargs)
328+
329+
additional_args = {
330+
"dataset_text_field": dataset_text_field,
331+
"dataset_kwargs": dataset_kwargs,
332+
}
333+
training_args = SFTConfig(**transformer_kwargs, **additional_args)
328334

329335
trainer = SFTTrainer(
330336
model=model,
331337
tokenizer=tokenizer,
332338
train_dataset=formatted_train_dataset,
333339
eval_dataset=formatted_validation_dataset,
334-
packing=train_args.packing,
335340
data_collator=data_collator,
336-
dataset_text_field=dataset_text_field,
337341
args=training_args,
338-
max_seq_length=max_seq_length,
339342
callbacks=trainer_callbacks,
340343
peft_config=peft_config,
341-
dataset_kwargs=dataset_kwargs,
342344
)
343345

344346
# We track additional metrics and experiment metadata after trainer object creation

0 commit comments

Comments
 (0)