Commit c7fd1bd

Merge branch 'main' into release/3.5
2 parents 21ccf17 + 0a14ac4 commit c7fd1bd

Showing 8 changed files with 23 additions and 10 deletions.


examples/train/grpo/qwen2_5_omni/grpo.sh

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ pip install transformers math_verify trl -U
 
 MAX_PIXELS=1003520 \
 NPROC_PER_NODE=4 \
+ENABLE_AUDIO_OUTPUT=1 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift rlhf \
     --rlhf_type grpo \

examples/train/multimodal/omni/sft.sh

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ pip install transformers -U
 nproc_per_node=4
 
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ENABLE_AUDIO_OUTPUT=1 \
 NPROC_PER_NODE=$nproc_per_node \
 VIDEO_MAX_PIXELS=50176 \
 FPS_MAX_FRAMES=12 \

examples/train/packing/qwen2_5_omni.sh

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 pip install transformers -U
 
 NPROC_PER_NODE=4 \
+ENABLE_AUDIO_OUTPUT=1 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 VIDEO_MAX_PIXELS=50176 \
 FPS_MAX_FRAMES=12 \
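
All three scripts above now export ENABLE_AUDIO_OUTPUT=1 before launching training. As a rough sketch only (the swift code that consumes this variable is not part of this commit), a boolean environment flag like this is typically read on the Python side as follows; the helper name is hypothetical:

import os

def audio_output_enabled() -> bool:
    # Hypothetical helper: treat '1'/'true'/'yes' (case-insensitive) as enabled.
    return os.environ.get('ENABLE_AUDIO_OUTPUT', '0').lower() in {'1', 'true', 'yes'}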

swift/llm/argument/rlhf_args.py

Lines changed: 4 additions & 1 deletion
@@ -200,7 +200,10 @@ def _init_external_vllm(self):
         from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
         if is_master():
             self.vllm_client = VLLMClient(
-                self.vllm_server_host, self.vllm_server_port, connection_timeout=self.vllm_server_timeout)
+                base_url=self.vllm_server_base_url,
+                host=self.vllm_server_host,
+                server_port=self.vllm_server_port,
+                connection_timeout=self.vllm_server_timeout)
             self.vllm_client.init_communicator()
 
     def _set_default(self):
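
The client is now built with explicit keyword arguments, adding base_url alongside host and server_port. A minimal sketch of how an endpoint could be resolved from either form (an assumption for illustration; the real VLLMClient logic lives in swift/trainers/rlhf_trainer/vllm_client.py and is not shown in this diff):

from typing import Optional

def resolve_endpoint(base_url: Optional[str], host: str, server_port: int) -> str:
    # Assumption: a full base_url, when provided, takes precedence over host/port.
    if base_url:
        return base_url.rstrip('/')
    return f'http://{host}:{server_port}'

# e.g. resolve_endpoint(None, '127.0.0.1', 8000) -> 'http://127.0.0.1:8000'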

swift/llm/train/sft.py

Lines changed: 4 additions & 3 deletions
@@ -142,11 +142,12 @@ def _save_trainer_state(self, trainer):
         training_args = trainer.args
         state = trainer.state
         if hasattr(state, 'last_model_checkpoint'):
-            if is_master() and self.args.create_checkpoint_symlink:
+            if self.args.create_checkpoint_symlink:
                 last_checkpoint = os.path.join(self.args.output_dir, 'last')
                 best_checkpoint = os.path.join(self.args.output_dir, 'best')
-                os.symlink(state.last_model_checkpoint, last_checkpoint)
-                os.symlink(state.best_model_checkpoint, best_checkpoint)
+                if is_master():
+                    os.symlink(state.last_model_checkpoint, last_checkpoint)
+                    os.symlink(state.best_model_checkpoint, best_checkpoint)
                 state.last_model_checkpoint = last_checkpoint
                 state.best_model_checkpoint = best_checkpoint
         else:
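
The symlink targets are now computed whenever create_checkpoint_symlink is set, with only the filesystem writes guarded by is_master(), so every rank records the same checkpoint paths in its trainer state. A condensed sketch of that pattern (illustrative names, not the exact swift code):

import os

def link_checkpoints(output_dir: str, state, is_master: bool) -> None:
    # All ranks compute the same symlink paths so trainer state stays consistent.
    last = os.path.join(output_dir, 'last')
    best = os.path.join(output_dir, 'best')
    if is_master:
        # Only the master rank touches the filesystem, avoiding races on shared storage.
        os.symlink(state.last_model_checkpoint, last)
        os.symlink(state.best_model_checkpoint, best)
    state.last_model_checkpoint = last
    state.best_model_checkpoint = best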

swift/trainers/mixin.py

Lines changed: 1 addition & 1 deletion
@@ -316,8 +316,8 @@ def clip_grad_norm_(self, parameters, *args, **kwargs):
     def _prepare_gradient_checkpointing(self, model) -> None:
         from swift.llm import HfConfigFactory, get_model_arch, deep_getattr, dynamic_gradient_checkpointing
         args = self.args
+        HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
         if args.gradient_checkpointing or args.vit_gradient_checkpointing:
-            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
             dynamic_gradient_checkpointing(model, args.vit_gradient_checkpointing)
         if args.gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
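
With this change, use_cache is disabled on the model config unconditionally during training preparation rather than only when gradient checkpointing is requested. A minimal sketch of the resulting ordering, assuming a Hugging Face style config with a use_cache attribute (HfConfigFactory's internals are not shown in this diff):

def prepare_gradient_checkpointing(model, gradient_checkpointing: bool) -> None:
    # KV caching is only useful at inference time, so it is turned off for training
    # regardless of whether gradient checkpointing is enabled.
    model.config.use_cache = False
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()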

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 6 additions & 3 deletions
@@ -266,6 +266,7 @@ def __init__(self,
         self.model_accepts_loss_kwargs = False
         self.padding_free = self.template.padding_free
         self.template.padding_free = False
+        self.template._packing = False
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
                 if self.is_deepspeed_enabled:

@@ -1196,10 +1197,12 @@ def _padding_free_output_hook(module, args, kwargs, result):
             result.last_hidden_state = torch.stack(unpacked_logits, dim=0)
             return result
 
-        llm_model = get_llm_model(model)
-
-        base_model = llm_model.model
         if self.padding_free:
+            llm_model = get_llm_model(model)
+            if hasattr(llm_model, 'thinker'):
+                base_model = llm_model.thinker.model
+            else:
+                base_model = llm_model.model
             remove_handle1 = base_model.register_forward_pre_hook(
                 _padding_free_input_hook, with_kwargs=True, prepend=True)
             remove_handle2 = base_model.register_forward_hook(_padding_free_output_hook, with_kwargs=True, prepend=True)
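
Both this hook registration and the Ulysses patch below pick the text backbone through a thinker attribute when one exists (the Qwen2.5-Omni style wrapper), falling back to llm_model.model otherwise. The branch can be condensed as in this sketch, which mirrors the diff rather than replacing it:

def resolve_base_model(llm_model):
    # Omni-style wrappers such as Qwen2.5-Omni expose the language model under `thinker`.
    thinker = getattr(llm_model, 'thinker', None)
    return thinker.model if thinker is not None else llm_model.model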

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 5 additions & 2 deletions
@@ -667,7 +667,10 @@ def pre_forward_split_hook(_self, args, kwargs):
 
         llm_model = get_llm_model(model)
 
-        base_model = llm_model.model
+        if hasattr(llm_model, 'thinker'):
+            base_model = llm_model.thinker.model
+        else:
+            base_model = llm_model.model
         if hasattr(base_model, 'language_model'):
             self.causal_mask_func = base_model.language_model._update_causal_mask
         else:

@@ -845,7 +848,7 @@ def rlhf_loss_scale_sp_func(_, *args, **kwargs):
         compute_acc_origin = metric.compute_acc
 
         def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]:
-
+            _, _, labels, _, _, _ = self.pad_and_split_inputs(None, None, labels, None, None, None)
             # Gather preds and labels across the sp group
             if isinstance(preds, np.ndarray):
                 preds = torch.from_numpy(preds).to(get_current_device())
