@@ -88,9 +88,7 @@ def __init__(self, policy, value_model) -> None:
         self.critic_backbone = getattr(value_model, value_model.base_model_prefix)

     def forward(self, **kwargs):
-        output = self.critic_backbone(
-            **kwargs,
-        )
+        output = self.critic_backbone(**kwargs)
         logits = self.value_model.score(output.hidden_states[-1])
         return self.policy(**kwargs), logits

@@ -100,14 +98,18 @@ class PPOTrainer(Trainer):

     @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True)
     @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
+    @deprecate_kwarg("policy", "0.15.0", "model", warn_if_greater_or_equal_version=True, raise_if_both_names=True)
+    @deprecate_kwarg(
+        "ref_policy", "0.15.0", "ref_model", warn_if_greater_or_equal_version=True, raise_if_both_names=True
+    )
     def __init__(
         self,
         args: PPOConfig,
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ],
-        policy: nn.Module,
-        ref_policy: Optional[nn.Module],
+        model: nn.Module,
+        ref_model: Optional[nn.Module],
         reward_model: nn.Module,
         train_dataset: Dataset,
         value_model: Optional[nn.Module] = None,
@@ -118,24 +120,24 @@ def __init__(
         callbacks: Optional[List[TrainerCallback]] = None,
         peft_config: Optional["PeftConfig"] = None,
     ) -> None:
-        if ref_policy is policy:
+        if ref_model is model:
             raise ValueError(
-                "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
-                "same as `policy`, you must make a copy of it, or `None` if you use peft."
+                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+                "same as `model`, you must make a copy of it, or `None` if you use peft."
             )

         self.args = args
         self.processing_class = processing_class
-        self.policy = policy
+        self.model = model

         # Define the collator if not provided
         if data_collator is None:
             data_collator = DataCollatorWithPadding(self.processing_class)

-        self.policy.generation_config.eos_token_id = (
+        self.model.generation_config.eos_token_id = (
             None  # disable `pad_token_id` and `eos_token_id` because we just want to
         )
-        self.policy.generation_config.pad_token_id = None  # generate tokens without truncation / padding
+        self.model.generation_config.pad_token_id = None  # generate tokens without truncation / padding

         # peft support
         if not is_peft_available() and peft_config is not None:
@@ -144,24 +146,24 @@ def __init__(
             )
         elif is_peft_available() and peft_config is not None:
             # if model is a peft model and we have a peft_confg, we merge and unload it first
-            if isinstance(self.policy, PeftModel):
-                self.policy = self.policy.merge_and_unload()
+            if isinstance(self.model, PeftModel):
+                self.model = self.model.merge_and_unload()

             # get peft model with the given config
-            self.policy = get_peft_model(self.policy, peft_config)
-            if args.bf16 and getattr(self.policy, "is_loaded_in_4bit", False):
-                peft_module_casting_to_bf16(self.policy)
+            self.model = get_peft_model(self.model, peft_config)
+            if args.bf16 and getattr(self.model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(self.model)

-        self.is_peft_model = is_peft_available() and isinstance(self.policy, PeftModel)
+        self.is_peft_model = is_peft_available() and isinstance(self.model, PeftModel)
         self.model_adapter_name = args.model_adapter_name
         self.ref_adapter_name = args.ref_adapter_name

-        if ref_policy:
-            self.ref_policy = ref_policy
+        if ref_model:
+            self.ref_model = ref_model
         elif self.is_peft_model:
-            self.ref_policy = None
+            self.ref_model = None
         else:
-            self.ref_policy = create_reference_model(self.policy)
+            self.ref_model = create_reference_model(self.model)

         self.reward_model = reward_model
         self.train_dataset = train_dataset
@@ -211,13 +213,13 @@ def __init__(
         #########
         # setup model, optimizer, and others
         #########
-        for module in [self.policy, self.ref_policy, self.value_model, self.reward_model]:
+        for module in [self.model, self.ref_model, self.value_model, self.reward_model]:
             if module is not None:
                 disable_dropout_in_model(module)
         if args.stop_token and args.stop_token == "eos":
             args.stop_token_id = processing_class.eos_token_id
-        self.model = PolicyAndValueWrapper(self.policy, self.value_model)
-        self.model.config = self.policy.config  # needed for pushing to hub
+        self.policy_and_value = PolicyAndValueWrapper(self.model, self.value_model)
+        self.policy_and_value.config = self.model.config  # needed for pushing to hub
         self.create_optimizer_and_scheduler(
             num_training_steps=args.num_total_batches
         )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
@@ -228,7 +230,7 @@ def __init__(
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
-            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+            self.callbacks, self.policy_and_value, self.processing_class, self.optimizer, self.lr_scheduler
         )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
         self.control = TrainerControl()
@@ -251,8 +253,8 @@ def __init__(
         os.makedirs(self.args.output_dir, exist_ok=True)

         # Add tags for models that have been loaded with the correct transformers version
-        if hasattr(self.model, "add_model_tags"):
-            self.model.add_model_tags(self._tag_names)
+        if hasattr(self.policy_and_value, "add_model_tags"):
+            self.policy_and_value.add_model_tags(self._tag_names)

         #########
         ### setup dataloader
@@ -267,7 +269,9 @@ def __init__(
         # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
         # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
         torch.manual_seed(args.seed)
-        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
+        self.policy_and_value, self.optimizer, self.dataloader = accelerator.prepare(
+            self.policy_and_value, self.optimizer, self.dataloader
+        )
         torch.manual_seed(self.local_seed)  # reset the local seed again

         self.eval_dataloader = DataLoader(
@@ -283,19 +287,19 @@ def __init__(
                 self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
             )

-            if self.ref_policy is None:
+            if self.ref_model is None:
                 if not self.is_peft_model:
                     raise ValueError("No reference model and model is not a Peft model.")
             else:
-                self.ref_policy = prepare_deepspeed(
-                    self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
+                self.ref_model = prepare_deepspeed(
+                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
                 )
         else:
-            if self.ref_policy is None:
+            if self.ref_model is None:
                 if not self.is_peft_model:
                     raise ValueError("No reference model and model is not a Peft model.")
             else:
-                self.ref_policy = self.ref_policy.to(self.accelerator.device)
+                self.ref_model = self.ref_model.to(self.accelerator.device)
             self.reward_model = self.reward_model.to(self.accelerator.device)

     def get_train_dataloader(self) -> DataLoader:
@@ -308,25 +312,25 @@ def get_eval_dataloader(self) -> DataLoader:
     def null_ref_context(self):
         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
         with self.accelerator.unwrap_model(
-            self.model.policy
+            self.policy_and_value.policy
         ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
             if self.ref_adapter_name:
-                self.model.policy.set_adapter(self.ref_adapter_name)
+                self.policy_and_value.policy.set_adapter(self.ref_adapter_name)
             yield
             if self.ref_adapter_name:
-                self.model.policy.set_adapter(self.model_adapter_name or "default")
+                self.policy_and_value.policy.set_adapter(self.model_adapter_name or "default")

     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
-        backup_model = self.model
-        self.model = self.model.policy  # save only the policy
+        backup_model = self.policy_and_value
+        self.policy_and_value = self.policy_and_value.policy  # save only the policy

         if self.is_deepspeed_enabled:
             backup_deepspeed = self.deepspeed
-            self.deepspeed = self.model
+            self.deepspeed = self.policy_and_value

         super().save_model(output_dir, _internal_call)

-        self.model = backup_model
+        self.policy_and_value = backup_model

         if self.is_deepspeed_enabled:
             self.deepspeed = backup_deepspeed
@@ -335,8 +339,8 @@ def train(self):
         args = self.args
         accelerator = self.accelerator
         optimizer = self.optimizer
-        model = self.model
-        ref_policy = self.ref_policy
+        model = self.policy_and_value
+        ref_policy = self.ref_model
         reward_model = self.reward_model
         processing_class = self.processing_class
         dataloader = self.dataloader
@@ -392,8 +396,8 @@ def repeat_generator():
 
         # backward compatibility
         if self.is_deepspeed_enabled:
-            self.deepspeed = self.model
-            self.model_wrapped = self.model
+            self.deepspeed = self.policy_and_value
+            self.model_wrapped = self.policy_and_value

         for update in range(1, args.num_total_batches + 1):
             self.state.episode += 1 * args.batch_size
@@ -680,7 +684,7 @@ def generate_completions(self, sampling: bool = False):
         )

         table = defaultdict(list)
-        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+        with unwrap_model_for_generation(self.policy_and_value, self.accelerator) as unwrapped_model:
             for batch in self.eval_dataloader:
                 query = batch["input_ids"]
                 with torch.no_grad():
@@ -743,16 +747,18 @@ def create_model_card(
         if not self.is_world_process_zero():
             return

-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
+        if hasattr(self.policy_and_value.config, "_name_or_path") and not os.path.isdir(
+            self.policy_and_value.config._name_or_path
+        ):
+            base_model = self.policy_and_value.config._name_or_path
         else:
             base_model = None

         tags = tags or []
         if isinstance(tags, str):
             tags = [tags]

-        if hasattr(self.model.config, "unsloth_version"):
+        if hasattr(self.policy_and_value.config, "unsloth_version"):
             tags.append("unsloth")

         citation = textwrap.dedent("""\
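For callers migrating to this change, a minimal sketch of the updated constructor call is below. It is illustrative only: the checkpoint name, the tiny dataset, and the PPOConfig fields are assumptions for the example and are not part of this diff. The old `policy=` / `ref_policy=` keyword names are still routed to the new ones via the `deprecate_kwarg` decorators added above and are flagged for removal in 0.15.0.

# Sketch of the renamed PPOTrainer keywords after this commit.
# Placeholder assumptions: the pythia checkpoint, the one-example dataset, and the PPOConfig values.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import PPOConfig, PPOTrainer

model_id = "EleutherAI/pythia-14m"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # DataCollatorWithPadding needs a pad token

policy_model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)  # separate copy, never the same object
reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
value_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)

# Minimal tokenized dataset; the trainer's dataloader reads batch["input_ids"].
train_dataset = Dataset.from_dict({"input_ids": [tokenizer.encode("hello world")]})

trainer = PPOTrainer(
    args=PPOConfig(output_dir="ppo-out"),  # placeholder config
    processing_class=tokenizer,
    model=policy_model,    # was `policy=` before this change
    ref_model=ref_model,   # was `ref_policy=`; may be None when training a peft adapter
    reward_model=reward_model,
    train_dataset=train_dataset,
    value_model=value_model,
)

Internally, the wrapped policy-plus-value module that used to live in `self.model` is now stored as `self.policy_and_value`, so `self.model` refers unambiguously to the policy passed in by the user.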