Skip to content

Commit 52547c5

Browse files
Remove extra_fields dead code [3/N]: Remove extra_fields from GRPOTrainer._generate_single_turn return value (#5264)
1 parent 5879856 commit 52547c5

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,7 +1351,7 @@ def _generate_single_turn(self, prompts: list):
13511351
]
13521352
logprobs = None # not used in this case
13531353

1354-
return prompt_ids, completion_ids, logprobs, {}
1354+
return prompt_ids, completion_ids, logprobs
13551355

13561356
def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs):
13571357
# Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt
@@ -1456,7 +1456,7 @@ async def _run_async_tools(async_coros):
14561456
break # all overlong, exit tool loop
14571457

14581458
# Generate new completions after tool execution
1459-
prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn(
1459+
prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs = self._generate_single_turn(
14601460
prompt_completion_tools
14611461
)
14621462

@@ -1549,7 +1549,8 @@ def _generate(self, prompts: list):
15491549
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
15501550
prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
15511551
else:
1552-
prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)
1552+
prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts)
1553+
extra_fields = {}
15531554

15541555
# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
15551556
if is_conversational({"prompt": prompts[0]}):

0 commit comments

Comments
 (0)