Skip to content

Commit 0a14ac4

Browse files
tastelikefeet and tastelikefeet authored
Fix omni grpo (#4469)
Co-authored-by: tastelikefeet <[email protected]>
1 parent 9dfa63a commit 0a14ac4

File tree

2 files changed, with 8 additions and 2 deletions.

2 files changed, with 8 additions and 2 deletions.

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1199,7 +1199,10 @@ def _padding_free_output_hook(module, args, kwargs, result):
11991199

12001200
if self.padding_free:
12011201
llm_model = get_llm_model(model)
1202-
base_model = llm_model.model
1202+
if hasattr(llm_model, 'thinker'):
1203+
base_model = llm_model.thinker.model
1204+
else:
1205+
base_model = llm_model.model
12031206
remove_handle1 = base_model.register_forward_pre_hook(
12041207
_padding_free_input_hook, with_kwargs=True, prepend=True)
12051208
remove_handle2 = base_model.register_forward_hook(_padding_free_output_hook, with_kwargs=True, prepend=True)

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -667,7 +667,10 @@ def pre_forward_split_hook(_self, args, kwargs):
667667

668668
llm_model = get_llm_model(model)
669669

670-
base_model = llm_model.model
670+
if hasattr(llm_model, 'thinker'):
671+
base_model = llm_model.thinker.model
672+
else:
673+
base_model = llm_model.model
671674
if hasattr(base_model, 'language_model'):
672675
self.causal_mask_func = base_model.language_model._update_causal_mask
673676
else:

0 commit comments

Comments
 (0)