Skip to content

Commit 9dfa63a

Browse files
authored
Fix create checkpoint symlink & grpo omni (#4468)
1 parent 1050701 commit 9dfa63a

File tree

5 files changed: +9 −6 lines

examples/train/grpo/qwen2_5_omni/grpo.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pip install transformers math_verify trl -U
33

44
MAX_PIXELS=1003520 \
55
NPROC_PER_NODE=4 \
6+
ENABLE_AUDIO_OUTPUT=1 \
67
CUDA_VISIBLE_DEVICES=0,1,2,3 \
78
swift rlhf \
89
--rlhf_type grpo \

examples/train/multimodal/omni/sft.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pip install transformers -U
55
nproc_per_node=4
66

77
CUDA_VISIBLE_DEVICES=0,1,2,3 \
8+
ENABLE_AUDIO_OUTPUT=1 \
89
NPROC_PER_NODE=$nproc_per_node \
910
VIDEO_MAX_PIXELS=50176 \
1011
FPS_MAX_FRAMES=12 \

examples/train/packing/qwen2_5_omni.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
pip install transformers -U
66

77
NPROC_PER_NODE=4 \
8+
ENABLE_AUDIO_OUTPUT=1 \
89
CUDA_VISIBLE_DEVICES=0,1,2,3 \
910
VIDEO_MAX_PIXELS=50176 \
1011
FPS_MAX_FRAMES=12 \

swift/llm/train/sft.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,12 @@ def _save_trainer_state(self, trainer):
142142
training_args = trainer.args
143143
state = trainer.state
144144
if hasattr(state, 'last_model_checkpoint'):
145-
if is_master() and self.args.create_checkpoint_symlink:
145+
if self.args.create_checkpoint_symlink:
146146
last_checkpoint = os.path.join(self.args.output_dir, 'last')
147147
best_checkpoint = os.path.join(self.args.output_dir, 'best')
148-
os.symlink(state.last_model_checkpoint, last_checkpoint)
149-
os.symlink(state.best_model_checkpoint, best_checkpoint)
148+
if is_master():
149+
os.symlink(state.last_model_checkpoint, last_checkpoint)
150+
os.symlink(state.best_model_checkpoint, best_checkpoint)
150151
state.last_model_checkpoint = last_checkpoint
151152
state.best_model_checkpoint = best_checkpoint
152153
else:

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,10 +1197,9 @@ def _padding_free_output_hook(module, args, kwargs, result):
11971197
result.last_hidden_state = torch.stack(unpacked_logits, dim=0)
11981198
return result
11991199

1200-
llm_model = get_llm_model(model)
1201-
1202-
base_model = llm_model.model
12031200
if self.padding_free:
1201+
llm_model = get_llm_model(model)
1202+
base_model = llm_model.model
12041203
remove_handle1 = base_model.register_forward_pre_hook(
12051204
_padding_free_input_hook, with_kwargs=True, prepend=True)
12061205
remove_handle2 = base_model.register_forward_hook(_padding_free_output_hook, with_kwargs=True, prepend=True)

0 commit comments

Comments (0)