Skip to content

Commit aa29050

Browse files
[bugfix] pass callbacks arg for ppo_trainer (#5637)
1 parent ab2133b commit aa29050

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

swift/trainers/rlhf_trainer/ppo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, *
3939
new_kwargs = {
4040
k: v
4141
for k, v in kwargs.items()
42-
if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset']
42+
if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset', 'callbacks']
4343
}
4444
parameters = inspect.signature(ppo_trainer_init).parameters
4545
if 'config' in parameters:

0 commit comments

Comments
 (0)