Skip to content

Commit c3f2b1c

Browse files
authored
Add num_completions_to_print for trl and grpo (axolotl-ai-cloud#2604)
1 parent 6ba5c0e commit c3f2b1c

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

src/axolotl/core/trainers/grpo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def set_training_args_kwargs(cls, cfg):
6363

6464
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
6565
grpo_args_kwargs["log_completions"] = trl.log_completions
66+
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
6667

6768
if trl.reward_weights:
6869
grpo_args_kwargs["reward_weights"] = trl.reward_weights

src/axolotl/utils/schemas/trl.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ class TRLConfig(BaseModel):
6767
default=False,
6868
json_schema_extra={"description": "Whether to log completions"},
6969
)
70+
num_completions_to_print: int | None = Field(
71+
default=None,
72+
json_schema_extra={
73+
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
74+
},
75+
)
7076
sync_ref_model: bool | None = Field(
7177
default=False,
7278
json_schema_extra={

0 commit comments

Comments
 (0)