Commit 16fa13c
👮 Deprecate policy in favor of model in PPOTrainer (huggingface#2386)
1 parent 453db5c commit 16fa13c

4 files changed: +60 -54 lines changed
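
For callers, the change is a pure keyword rename: `policy` becomes `model` and `ref_policy` becomes `ref_model`. A minimal migration sketch, assuming `training_args`, `tokenizer`, `policy`, `ref_policy`, `reward_model`, `value_model`, and the tokenized `train_dataset` are prepared as in examples/scripts/ppo/ppo.py:

```python
from trl import PPOTrainer

# Old keywords (still accepted by this commit, but they emit a deprecation
# warning; the decorators below reference version 0.15.0):
# trainer = PPOTrainer(args=training_args, processing_class=tokenizer,
#                      policy=policy, ref_policy=ref_policy,
#                      reward_model=reward_model, value_model=value_model,
#                      train_dataset=train_dataset)

# New keywords:
trainer = PPOTrainer(
    args=training_args,
    processing_class=tokenizer,
    model=policy,          # formerly `policy=`
    ref_model=ref_policy,  # formerly `ref_policy=`
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
)
```

Passing both the old and the new name for the same argument raises an error (`raise_if_both_names=True` in the decorators below).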

docs/source/detoxifying_a_lm.mdx (2 additions, 2 deletions)

@@ -105,8 +105,8 @@ and the optimizer will take care of computing the gradients in `bfloat16` precis
 </div>
 
 ```python
-ref_policy = create_reference_model(model, num_shared_layers=6)
-trainer = PPOTrainer(..., ref_policy=ref_policy)
+ref_model = create_reference_model(model, num_shared_layers=6)
+trainer = PPOTrainer(..., ref_model=ref_model)
 ```
 
 In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).

examples/scripts/ppo/ppo.py (2 additions, 2 deletions)

@@ -154,8 +154,8 @@ def tokenize(element):
     trainer = PPOTrainer(
         args=training_args,
         processing_class=tokenizer,
-        policy=policy,
-        ref_policy=ref_policy,
+        model=policy,
+        ref_model=ref_policy,
         reward_model=reward_model,
         value_model=value_model,
         train_dataset=train_dataset,

examples/scripts/ppo/ppo_tldr.py (2 additions, 2 deletions)

@@ -165,8 +165,8 @@ def tokenize(element):
     trainer = PPOTrainer(
         args=training_args,
         processing_class=tokenizer,
-        policy=policy,
-        ref_policy=ref_policy,
+        model=policy,
+        ref_model=ref_policy,
         reward_model=reward_model,
         value_model=value_model,
         train_dataset=train_dataset,

trl/trainer/ppo_trainer.py (54 additions, 48 deletions)

@@ -88,9 +88,7 @@ def __init__(self, policy, value_model) -> None:
         self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
 
     def forward(self, **kwargs):
-        output = self.critic_backbone(
-            **kwargs,
-        )
+        output = self.critic_backbone(**kwargs)
         logits = self.value_model.score(output.hidden_states[-1])
         return self.policy(**kwargs), logits
 
@@ -100,14 +98,18 @@ class PPOTrainer(Trainer):
 
     @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True)
     @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
+    @deprecate_kwarg("policy", "0.15.0", "model", warn_if_greater_or_equal_version=True, raise_if_both_names=True)
+    @deprecate_kwarg(
+        "ref_policy", "0.15.0", "ref_model", warn_if_greater_or_equal_version=True, raise_if_both_names=True
+    )
     def __init__(
         self,
         args: PPOConfig,
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ],
-        policy: nn.Module,
-        ref_policy: Optional[nn.Module],
+        model: nn.Module,
+        ref_model: Optional[nn.Module],
         reward_model: nn.Module,
         train_dataset: Dataset,
         value_model: Optional[nn.Module] = None,
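
The old keyword names keep working through the `deprecate_kwarg` decorators added above. Purely as an illustration (a simplified stand-in, not the actual implementation shipped with `transformers`), the remapping mechanism can be sketched like this:

```python
import functools
import warnings


def deprecate_kwarg_sketch(old_name, version, new_name=None, raise_if_both_names=False, **_ignored):
    """Illustrative kwarg-deprecation decorator: remap `old_name` to `new_name`,
    warn about the planned removal version, and optionally reject calls that
    pass both names at once."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                if raise_if_both_names and new_name in kwargs:
                    raise ValueError(f"Pass `{new_name}` only, not both `{old_name}` and `{new_name}`.")
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Usage mirroring the decorators added in this commit (hypothetical class):
# @deprecate_kwarg_sketch("policy", "0.15.0", "model", raise_if_both_names=True)
# def __init__(self, ..., model, ...): ...
```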
@@ -118,24 +120,24 @@ def __init__(
         callbacks: Optional[List[TrainerCallback]] = None,
         peft_config: Optional["PeftConfig"] = None,
     ) -> None:
-        if ref_policy is policy:
+        if ref_model is model:
             raise ValueError(
-                "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
-                "same as `policy`, you must make a copy of it, or `None` if you use peft."
+                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+                "same as `model`, you must make a copy of it, or `None` if you use peft."
             )
 
         self.args = args
         self.processing_class = processing_class
-        self.policy = policy
+        self.model = model
 
         # Define the collator if not provided
         if data_collator is None:
             data_collator = DataCollatorWithPadding(self.processing_class)
 
-        self.policy.generation_config.eos_token_id = (
+        self.model.generation_config.eos_token_id = (
             None  # disable `pad_token_id` and `eos_token_id` because we just want to
         )
-        self.policy.generation_config.pad_token_id = None  # generate tokens without truncation / padding
+        self.model.generation_config.pad_token_id = None  # generate tokens without truncation / padding
 
         # peft support
         if not is_peft_available() and peft_config is not None:
@@ -144,24 +146,24 @@ def __init__(
             )
         elif is_peft_available() and peft_config is not None:
             # if model is a peft model and we have a peft_confg, we merge and unload it first
-            if isinstance(self.policy, PeftModel):
-                self.policy = self.policy.merge_and_unload()
+            if isinstance(self.model, PeftModel):
+                self.model = self.model.merge_and_unload()
 
             # get peft model with the given config
-            self.policy = get_peft_model(self.policy, peft_config)
-            if args.bf16 and getattr(self.policy, "is_loaded_in_4bit", False):
-                peft_module_casting_to_bf16(self.policy)
+            self.model = get_peft_model(self.model, peft_config)
+            if args.bf16 and getattr(self.model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(self.model)
 
-        self.is_peft_model = is_peft_available() and isinstance(self.policy, PeftModel)
+        self.is_peft_model = is_peft_available() and isinstance(self.model, PeftModel)
         self.model_adapter_name = args.model_adapter_name
         self.ref_adapter_name = args.ref_adapter_name
 
-        if ref_policy:
-            self.ref_policy = ref_policy
+        if ref_model:
+            self.ref_model = ref_model
         elif self.is_peft_model:
-            self.ref_policy = None
+            self.ref_model = None
         else:
-            self.ref_policy = create_reference_model(self.policy)
+            self.ref_model = create_reference_model(self.model)
 
         self.reward_model = reward_model
         self.train_dataset = train_dataset
@@ -211,13 +213,13 @@ def __init__(
         #########
         # setup model, optimizer, and others
         #########
-        for module in [self.policy, self.ref_policy, self.value_model, self.reward_model]:
+        for module in [self.model, self.ref_model, self.value_model, self.reward_model]:
             if module is not None:
                 disable_dropout_in_model(module)
         if args.stop_token and args.stop_token == "eos":
             args.stop_token_id = processing_class.eos_token_id
-        self.model = PolicyAndValueWrapper(self.policy, self.value_model)
-        self.model.config = self.policy.config  # needed for pushing to hub
+        self.policy_and_value = PolicyAndValueWrapper(self.model, self.value_model)
+        self.policy_and_value.config = self.model.config  # needed for pushing to hub
         self.create_optimizer_and_scheduler(
             num_training_steps=args.num_total_batches
         )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
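
This is the pivotal hunk of the refactor: the combined policy-and-value wrapper that used to be stored in `self.model` now lives in `self.policy_and_value`, which frees `self.model` to hold the policy passed by the caller. A short orientation sketch (attribute names taken from this diff; depending on the run, the objects may additionally be wrapped by `accelerate` or PEFT):

```python
# `trainer` is a PPOTrainer built with model=policy, ref_model=ref_policy,
# and value_model=value_model, as in the example scripts above.
policy_passed_in = trainer.model       # was `trainer.policy` before this commit
reference_model = trainer.ref_model    # was `trainer.ref_policy`
wrapper = trainer.policy_and_value     # PolicyAndValueWrapper(model, value_model),
                                       # previously stored in `trainer.model`
```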
@@ -228,7 +230,7 @@ def __init__(
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
-            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+            self.callbacks, self.policy_and_value, self.processing_class, self.optimizer, self.lr_scheduler
         )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
         self.control = TrainerControl()
@@ -251,8 +253,8 @@ def __init__(
         os.makedirs(self.args.output_dir, exist_ok=True)
 
         # Add tags for models that have been loaded with the correct transformers version
-        if hasattr(self.model, "add_model_tags"):
-            self.model.add_model_tags(self._tag_names)
+        if hasattr(self.policy_and_value, "add_model_tags"):
+            self.policy_and_value.add_model_tags(self._tag_names)
 
         #########
         ### setup dataloader
@@ -267,7 +269,9 @@ def __init__(
         # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
         # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
         torch.manual_seed(args.seed)
-        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
+        self.policy_and_value, self.optimizer, self.dataloader = accelerator.prepare(
+            self.policy_and_value, self.optimizer, self.dataloader
+        )
         torch.manual_seed(self.local_seed)  # reset the local seed again
 
         self.eval_dataloader = DataLoader(
@@ -283,19 +287,19 @@ def __init__(
                 self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
             )
 
-            if self.ref_policy is None:
+            if self.ref_model is None:
                 if not self.is_peft_model:
                     raise ValueError("No reference model and model is not a Peft model.")
             else:
-                self.ref_policy = prepare_deepspeed(
-                    self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
+                self.ref_model = prepare_deepspeed(
+                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
                 )
         else:
-            if self.ref_policy is None:
+            if self.ref_model is None:
                 if not self.is_peft_model:
                     raise ValueError("No reference model and model is not a Peft model.")
             else:
-                self.ref_policy = self.ref_policy.to(self.accelerator.device)
+                self.ref_model = self.ref_model.to(self.accelerator.device)
             self.reward_model = self.reward_model.to(self.accelerator.device)
 
     def get_train_dataloader(self) -> DataLoader:
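
As with the old `ref_policy`, `ref_model` may be omitted: with PEFT the trainer keeps it as `None` and recovers reference behaviour by disabling the adapters (see `null_ref_context` below); without PEFT it builds one via `create_reference_model`. A hedged usage sketch, assuming `peft` is installed and the remaining objects come from the example script:

```python
from peft import LoraConfig

# Hypothetical LoRA settings; any valid PeftConfig works here.
peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")

trainer = PPOTrainer(
    args=training_args,
    processing_class=tokenizer,
    model=policy,
    ref_model=None,  # allowed: with peft_config, adapters are disabled to emulate the reference model
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
    peft_config=peft_config,
)
```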
@@ -308,25 +312,25 @@ def get_eval_dataloader(self) -> DataLoader:
     def null_ref_context(self):
         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
         with self.accelerator.unwrap_model(
-            self.model.policy
+            self.policy_and_value.policy
         ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
             if self.ref_adapter_name:
-                self.model.policy.set_adapter(self.ref_adapter_name)
+                self.policy_and_value.policy.set_adapter(self.ref_adapter_name)
             yield
             if self.ref_adapter_name:
-                self.model.policy.set_adapter(self.model_adapter_name or "default")
+                self.policy_and_value.policy.set_adapter(self.model_adapter_name or "default")
 
     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
-        backup_model = self.model
-        self.model = self.model.policy  # save only the policy
+        backup_model = self.policy_and_value
+        self.policy_and_value = self.policy_and_value.policy  # save only the policy
 
         if self.is_deepspeed_enabled:
             backup_deepspeed = self.deepspeed
-            self.deepspeed = self.model
+            self.deepspeed = self.policy_and_value
 
         super().save_model(output_dir, _internal_call)
 
-        self.model = backup_model
+        self.policy_and_value = backup_model
 
         if self.is_deepspeed_enabled:
             self.deepspeed = backup_deepspeed
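
Per the `# save only the policy` comment above, `save_model` temporarily swaps the wrapper for its policy so that only the policy weights end up in the checkpoint. A minimal usage sketch (the output path is a placeholder):

```python
trainer.train()
trainer.save_model("./ppo-policy")         # intended to write the policy only, not the value head
tokenizer.save_pretrained("./ppo-policy")  # optional: ship the tokenizer with the checkpoint
```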
@@ -335,8 +339,8 @@ def train(self):
         args = self.args
         accelerator = self.accelerator
         optimizer = self.optimizer
-        model = self.model
-        ref_policy = self.ref_policy
+        model = self.policy_and_value
+        ref_policy = self.ref_model
         reward_model = self.reward_model
         processing_class = self.processing_class
         dataloader = self.dataloader
@@ -392,8 +396,8 @@ def repeat_generator():
 
         # backward compatibility
         if self.is_deepspeed_enabled:
-            self.deepspeed = self.model
-            self.model_wrapped = self.model
+            self.deepspeed = self.policy_and_value
+            self.model_wrapped = self.policy_and_value
 
         for update in range(1, args.num_total_batches + 1):
             self.state.episode += 1 * args.batch_size
@@ -680,7 +684,7 @@ def generate_completions(self, sampling: bool = False):
         )
 
         table = defaultdict(list)
-        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+        with unwrap_model_for_generation(self.policy_and_value, self.accelerator) as unwrapped_model:
             for batch in self.eval_dataloader:
                 query = batch["input_ids"]
                 with torch.no_grad():
@@ -743,16 +747,18 @@ def create_model_card(
         if not self.is_world_process_zero():
             return
 
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
+        if hasattr(self.policy_and_value.config, "_name_or_path") and not os.path.isdir(
+            self.policy_and_value.config._name_or_path
+        ):
+            base_model = self.policy_and_value.config._name_or_path
         else:
             base_model = None
 
         tags = tags or []
         if isinstance(tags, str):
             tags = [tags]
 
-        if hasattr(self.model.config, "unsloth_version"):
+        if hasattr(self.policy_and_value.config, "unsloth_version"):
             tags.append("unsloth")
 
         citation = textwrap.dedent("""\
