5151from ..core import masked_mean , masked_whiten
5252from ..models import create_reference_model
5353from ..models .utils import unwrap_model_for_generation
54- from ..trainer .utils import (
54+ from .ppo_config import PPOConfig
55+ from .utils import (
5556 OnlineTrainerState ,
5657 batch_generation ,
5758 disable_dropout_in_model ,
5859 exact_div ,
5960 first_true_indices ,
6061 forward ,
62+ generate_model_card ,
6163 get_reward ,
64+ peft_module_casting_to_bf16 ,
6265 prepare_deepspeed ,
6366 print_rich_table ,
6467 truncate_response ,
6568)
66- from .ppo_config import PPOConfig
67- from .utils import generate_model_card , peft_module_casting_to_bf16
6869
6970
7071if is_peft_available ():
@@ -97,10 +98,11 @@ def forward(self, **kwargs):
9798class PPOTrainer (Trainer ):
9899 _tag_names = ["trl" , "ppo" ]
99100
101+ @deprecate_kwarg ("config" , new_name = "args" , version = "0.15.0" , raise_if_both_names = True )
100102 @deprecate_kwarg ("tokenizer" , new_name = "processing_class" , version = "0.15.0" , raise_if_both_names = True )
101103 def __init__ (
102104 self ,
103- config : PPOConfig ,
105+ args : PPOConfig ,
104106 processing_class : Optional [
105107 Union [PreTrainedTokenizerBase , BaseImageProcessor , FeatureExtractionMixin , ProcessorMixin ]
106108 ],
@@ -122,8 +124,7 @@ def __init__(
122124 "same as `policy`, you must make a copy of it, or `None` if you use peft."
123125 )
124126
125- self .args = config
126- args = config
127+ self .args = args
127128 self .processing_class = processing_class
128129 self .policy = policy
129130
0 commit comments