Skip to content

Commit ee3cbe1

Browse files
authored
💾 Deprecate config in favor of args in PPOTrainer (huggingface#2384)
1 parent 17e8060 commit ee3cbe1

File tree

5 files changed

+11
-10
lines changed

5 files changed

+11
-10
lines changed

examples/research_projects/tools/python_interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def solution():
154154
optimize_cuda_cache=True,
155155
)
156156

157  -    ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
157  +    ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
158158
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
159159

160160
# text env

examples/research_projects/tools/triviaqa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class ScriptArguments:
105105
seed=script_args.seed,
106106
optimize_cuda_cache=True,
107107
)
108  -    ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
108  +    ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
109109
dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
110110
local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
111111
dataset = dataset.shuffle(local_seed)

examples/scripts/ppo/ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def tokenize(element):
152152
# Training
153153
################
154154
trainer = PPOTrainer(
155  -        config=training_args,
155  +        args=training_args,
156156
processing_class=tokenizer,
157157
policy=policy,
158158
ref_policy=ref_policy,

examples/scripts/ppo/ppo_tldr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def tokenize(element):
163163
# Training
164164
################
165165
trainer = PPOTrainer(
166  -        config=training_args,
166  +        args=training_args,
167167
processing_class=tokenizer,
168168
policy=policy,
169169
ref_policy=ref_policy,

trl/trainer/ppo_trainer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,21 @@
5151
from ..core import masked_mean, masked_whiten
5252
from ..models import create_reference_model
5353
from ..models.utils import unwrap_model_for_generation
54   - from ..trainer.utils import (
54   + from .ppo_config import PPOConfig
55   + from .utils import (
5556
OnlineTrainerState,
5657
batch_generation,
5758
disable_dropout_in_model,
5859
exact_div,
5960
first_true_indices,
6061
forward,
62   +     generate_model_card,
6163
get_reward,
64   +     peft_module_casting_to_bf16,
6265
prepare_deepspeed,
6366
print_rich_table,
6467
truncate_response,
6568
)
66   - from .ppo_config import PPOConfig
67   - from .utils import generate_model_card, peft_module_casting_to_bf16
6869

6970

7071
if is_peft_available():
@@ -97,10 +98,11 @@ def forward(self, **kwargs):
9798
class PPOTrainer(Trainer):
9899
_tag_names = ["trl", "ppo"]
99100

101  +     @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True)
100102
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
101103
def __init__(
102104
self,
103  -         config: PPOConfig,
105  +         args: PPOConfig,
104106
processing_class: Optional[
105107
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
106108
],
@@ -122,8 +124,7 @@ def __init__(
122124
"same as `policy`, you must make a copy of it, or `None` if you use peft."
123125
)
124126

125  -         self.args = config
126  -         args = config
127  +         self.args = args
127128
self.processing_class = processing_class
128129
self.policy = policy
129130

0 commit comments

Comments
 (0)