Skip to content

Commit 6ee1317

Browse files
authored
[bugfix] fix max_shard_size transformers 5.x (#8209)
1 parent 42a0809 commit 6ee1317

File tree

4 files changed

+17
-35
lines changed

4 files changed

+17
-35
lines changed

docs/source/Instruction/Command-line-parameters.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,10 @@ gradient_checkpointing: true
219219
- router_aux_loss_coef: 用于moe模型训练时,设置 aux_loss 的权重,默认为`0.`
220220
- enable_dft_loss: 是否在SFT训练中使用[DFT](https://arxiv.org/abs/2508.05629) (Dynamic Fine-Tuning) loss,默认为False。
221221
- enable_channel_loss: 启用channel loss,默认为`False`。你需要在数据集中准备"channel"字段,ms-swift会根据该字段分组统计loss(若未准备"channel"字段,则归为默认`None` channel)。数据集格式参考[channel loss](../Customization/Custom-dataset.md#channel-loss)。channel loss兼容packing/padding_free/loss_scale等技术。
222+
- safe_serialization: 是否存储为safetensors,默认为True。
223+
- max_shard_size: 单个存储文件(checkpoint分片)的最大大小,默认为'5GB'。
222224
- logging_dir: tensorboard日志保存路径。默认为None,即设置为`f'{self.output_dir}/runs'`
223-
- 🔥predict_with_generate: 验证时使用生成式的方式,默认为False。
225+
- predict_with_generate: 验证时使用生成式的方式,默认为False。
224226
- metric_for_best_model: 默认为None,即当`predict_with_generate`设置为False时,设置为'loss',否则设置为'rouge-l'(在PPO训练时,不进行默认值设置;GRPO训练设置为'reward')。
225227
- greater_is_better: 默认为None,即当`metric_for_best_model`含'loss'时,设置为False,否则设置为True。
226228
- max_epochs: 训练到`max_epochs`时强制退出训练,并对权重进行验证和保存。该参数在使用流式数据集时很有用。默认为None。

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ This list inherits from the Transformers `Seq2SeqTrainingArguments`, with ms-swi
222222
- router_aux_loss_coef: Used in MoE model training to set the weight of auxiliary loss. Default is `0.`.
223223
- enable_dft_loss: Whether to use [DFT](https://arxiv.org/abs/2508.05629) (Dynamic Fine-Tuning) loss during SFT training. Default is `False`.
224224
- enable_channel_loss: Enable channel-based loss. Default is `False`. Requires a `"channel"` field in the dataset. ms-swift groups and computes loss by this field (samples without `"channel"` are grouped into the default `None` channel). Dataset format reference: [channel loss](../Customization/Custom-dataset.md#channel-loss). Channel loss is compatible with packing, padding_free, and loss_scale techniques.
225+
- safe_serialization: Whether to save the model in safetensors format. Default is True.
226+
- max_shard_size: Maximum size of a single checkpoint shard file when saving the model; default is '5GB'.
225227
- logging_dir: Directory for TensorBoard logs. Default is `None`, automatically set to `f'{self.output_dir}/runs'`.
226228
- predict_with_generate: Use generation during evaluation. Default is `False`.
227229
- metric_for_best_model: Default is `None`. If `predict_with_generate=False`, it's set to `'loss'`; otherwise `'rouge-l'` (in PPO training, no default; in GRPO, set to `'reward'`).

swift/trainers/arguments.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class TrainArgumentsMixin:
135135
router_aux_loss_coef: float = 0.
136136
enable_dft_loss: bool = False # https://arxiv.org/abs/2508.05629
137137
enable_channel_loss: bool = False
138+
safe_serialization: bool = True
139+
max_shard_size: str = '5GB'
138140

139141
weight_decay: float = 0.1
140142
adam_beta2: float = 0.95

swift/trainers/mixin.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,6 @@
5454
from .utils import (can_return_loss, dynamic_gradient_checkpointing, find_labels, get_function, get_resume_dir,
5555
is_instance_of_ms_model, patch_modelscope_hub_timeout, replace_index_file)
5656

57-
try:
58-
from trl import AutoModelForCausalLMWithValueHead
59-
except (ImportError, RuntimeError):
60-
AutoModelForCausalLMWithValueHead = None
61-
6257
logger = get_logger()
6358

6459

@@ -275,9 +270,7 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
275270
# model
276271
supported_classes = (SwiftModel, PreTrainedModel, PeftModel)
277272
supported_names = ('SentenceTransformer', )
278-
if AutoModelForCausalLMWithValueHead is not None:
279-
supported_classes = supported_classes + (AutoModelForCausalLMWithValueHead, )
280-
save_safetensors = getattr(self.args, 'save_safetensors', True)
273+
safe_serialization = self.args.safe_serialization
281274
use_flash_ckpt = self.args.use_flash_ckpt
282275

283276
if not isinstance(self.model, supported_classes) and self.model.__class__.__name__ not in supported_names:
@@ -286,7 +279,7 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
286279

287280
_unwrap_model = unwrap_model(self.model)
288281
if isinstance(_unwrap_model, supported_classes):
289-
save_kwargs = {'state_dict': state_dict}
282+
save_kwargs = {'state_dict': state_dict, 'max_shard_size': self.args.max_shard_size}
290283
if isinstance(_unwrap_model, PeftModel):
291284
save_kwargs['selected_adapters'] = ['default']
292285
if use_flash_ckpt:
@@ -296,33 +289,16 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
296289
save_function=self.flash_checkpointer.ckpt_agent.save,
297290
**save_kwargs)
298291
else:
299-
_unwrap_model.save_pretrained(output_dir, safe_serialization=save_safetensors, **save_kwargs)
292+
_unwrap_model.save_pretrained(output_dir, safe_serialization=safe_serialization, **save_kwargs)
300293
else:
301294
logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
302295
if use_flash_ckpt:
303296
self.flash_checkpointer.ckpt_agent.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
304297
else:
305-
if save_safetensors:
298+
if safe_serialization:
306299
safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
307300
else:
308301
torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
309-
elif AutoModelForCausalLMWithValueHead and isinstance(self.model, AutoModelForCausalLMWithValueHead):
310-
# save reward model
311-
state_dict = self.model.state_dict()
312-
decoder_state_dict, v_head_state_dict = {}, {}
313-
for name, param in state_dict.items():
314-
if name.startswith('v_head.'):
315-
v_head_state_dict[name] = param
316-
else:
317-
decoder_state_dict[name.replace('pretrained_model.', '', 1)] = param
318-
self.model.pretrained_model.save_pretrained(
319-
output_dir, state_dict=decoder_state_dict or None, safe_serialization=save_safetensors)
320-
if save_safetensors:
321-
from safetensors.torch import save_file
322-
save_file(
323-
v_head_state_dict, os.path.join(output_dir, 'value_head.safetensors'), metadata={'format': 'pt'})
324-
else:
325-
torch.save(v_head_state_dict, os.path.join(output_dir, 'value_head.bin'))
326302
elif is_instance_of_ms_model(self.model):
327303
if use_flash_ckpt:
328304
PreTrainedModel.save_pretrained(
@@ -334,13 +310,13 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
334310
else:
335311
# modelscope save_pretrained does not support safe_serialization
336312
PreTrainedModel.save_pretrained(
337-
self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
313+
self.model, output_dir, state_dict=state_dict, safe_serialization=safe_serialization)
338314
elif self.args.tuner_type in tuners_map:
339315
tuners_map[self.args.tuner_type].save_pretrained(
340-
self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
316+
self.model, output_dir, state_dict=state_dict, safe_serialization=safe_serialization)
341317
else:
342318
if self.model.__class__.__name__ != 'SentenceTransformer':
343-
save_kwargs = {'state_dict': state_dict}
319+
save_kwargs = {'state_dict': state_dict, 'max_shard_size': self.args.max_shard_size}
344320
if isinstance(self.model, PeftModel):
345321
save_kwargs['selected_adapters'] = ['default']
346322
if use_flash_ckpt:
@@ -350,7 +326,7 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
350326
save_function=self.flash_checkpointer.ckpt_agent.save,
351327
**save_kwargs)
352328
else:
353-
self.model.save_pretrained(output_dir, safe_serialization=save_safetensors, **save_kwargs)
329+
self.model.save_pretrained(output_dir, safe_serialization=safe_serialization, **save_kwargs)
354330
else:
355331

356332
@contextmanager
@@ -373,7 +349,7 @@ def save_context():
373349
safe_serialization=False,
374350
save_function=self.flash_checkpointer.ckpt_agent.save)
375351
else:
376-
self.model.save_pretrained(output_dir, safe_serialization=save_safetensors)
352+
self.model.save_pretrained(output_dir, safe_serialization=safe_serialization)
377353
# copy sentencetransformers files
378354
copy_files_by_pattern(
379355
self.model.model_dir, output_dir, '*.py', exclude_patterns=['model.safetensors.index.json'])
@@ -636,7 +612,7 @@ def _save_flash_checkpoint(self, model, trial, metrics=None):
636612
rng_states,
637613
os.path.join(output_dir, f'rng_state_{self.args.process_index}.pth'),
638614
)
639-
if self.args.save_safetensors:
615+
if self.args.safe_serialization:
640616
torch.save({'safe_serialization': True}, 'safe_serialization')
641617
replace_index_file(output_dir)
642618

0 commit comments

Comments (0)