Commit c7baf18

hjh0119 authored and committed; co-authored by gemini-code-assist[bot]
[bugfix] grpo length context compatible with latest set_default_max_tokens (#5154)
* compatible with latest set_default_max_tokens

* Update swift/trainers/rlhf_trainer/grpo_trainer.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 340bf41 · commit c7baf18

File tree

1 file changed: +3 additions, −4 deletions


swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 3 additions & 4 deletions
@@ -1504,12 +1504,11 @@ def multi_turn_completion_length_context(self):
         original_fn = self.engine.set_default_max_tokens
         original_max_len = self.engine.max_model_len
 
-        def set_default_max_tokens(_self, request_config: RequestConfig, inputs: InputsType) -> None:
+        def set_default_max_tokens(_self, request_config: RequestConfig, inputs: Dict[str, Any]) -> None:
             # Calculate required context window
             original_max_len = _self.max_model_len or 8192
-            if isinstance(inputs, dict):
-                inputs = [inputs]
-            prompt_tokens = max(_self._get_num_tokens(inp) for inp in inputs)
+            assert isinstance(inputs, dict)
+            prompt_tokens = _self._get_num_tokens(inputs)
 
             if not hasattr(_self, 'set_grpo_max_model_len'):
                 # set max model len in first round
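
For context, the surrounding multi_turn_completion_length_context helper temporarily monkey-patches the engine's set_default_max_tokens so that, during multi-turn GRPO rollouts, the default completion budget leaves room for the prompt inside the model's context window. Below is a minimal, self-contained sketch of that pattern; MockEngine, RequestConfig, _get_num_tokens, and completion_length_context here are simplified stand-ins, not the actual ms-swift classes, and the real trainer differs in detail.

    from contextlib import contextmanager
    from dataclasses import dataclass
    from types import MethodType
    from typing import Any, Dict, Optional


    @dataclass
    class RequestConfig:
        # None means "fall back to the engine's default budget".
        max_tokens: Optional[int] = None


    class MockEngine:
        max_model_len = 4096

        def _get_num_tokens(self, inputs: Dict[str, Any]) -> int:
            # Stand-in for real tokenization: count whitespace-separated words.
            return len(inputs['prompt'].split())

        def set_default_max_tokens(self, request_config: RequestConfig, inputs: Dict[str, Any]) -> None:
            # Post-change contract: `inputs` is a single dict, not a list of dicts.
            if request_config.max_tokens is None:
                request_config.max_tokens = self.max_model_len


    @contextmanager
    def completion_length_context(engine):
        # Save the original bound method so it can be restored on exit.
        original_fn = engine.set_default_max_tokens

        def set_default_max_tokens(_self, request_config: RequestConfig, inputs: Dict[str, Any]) -> None:
            assert isinstance(inputs, dict)  # the new single-dict signature
            prompt_tokens = _self._get_num_tokens(inputs)
            # Budget only what remains of the context window after the prompt.
            available = (_self.max_model_len or 8192) - prompt_tokens
            request_config.max_tokens = min(request_config.max_tokens or available, available)

        engine.set_default_max_tokens = MethodType(set_default_max_tokens, engine)
        try:
            yield
        finally:
            engine.set_default_max_tokens = original_fn


    if __name__ == '__main__':
        engine = MockEngine()
        cfg = RequestConfig()
        with completion_length_context(engine):
            engine.set_default_max_tokens(cfg, {'prompt': 'solve step by step'})
        print(cfg.max_tokens)  # 4092 == 4096 context window minus 4 prompt tokens

The diff itself makes one behavioral change to this pattern: since the upstream set_default_max_tokens now receives a single request's inputs dict rather than a list, the old list-normalization branch and the max(...) over multiple inputs are unnecessary, and the patched function asserts the single-dict contract instead.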
