2 changes: 2 additions & 0 deletions docs/source/Instruction/Command-line-parameters.md
@@ -250,6 +250,8 @@ gradient_checkpointing: true
- 🔥neftune_noise_alpha: Noise coefficient added by NEFTune. Defaults to 0; commonly set to 5, 10, or 15.
- 🔥use_liger_kernel: Whether to enable the [Liger](https://github.com/linkedin/Liger-Kernel) kernel to accelerate training and reduce memory consumption. Defaults to False. For an example shell script, see [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/liger).
- Note: liger_kernel does not support device_map; use DDP/DeepSpeed for multi-GPU training. liger_kernel currently only supports `task_type='causal_lm'`.
- use_tiled_mlp: Whether to enable Tiled MLP for memory-efficient long-sequence training. When enabled, MLP layers are replaced with a tiled implementation that splits the sequence into shards to reduce memory usage. Defaults to False.
- tiled_mlp_num_shards: Number of shards the sequence is split into for the Tiled MLP computation. Defaults to None, which is treated as 4. Larger values reduce memory but may increase computation time.
- average_tokens_across_devices: Whether to average token counts across devices. If set to True, `num_tokens_in_batch` is synchronized via all_reduce for accurate loss computation. Defaults to False.
- max_grad_norm: Gradient clipping. Defaults to 1.
- Note: the grad_norm recorded in the logs is the value before clipping.
2 changes: 2 additions & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -252,6 +252,8 @@ Other important parameters:
- 🔥neftune_noise_alpha: Noise magnitude for NEFTune. Default is 0. Common values: 5, 10, 15.
- 🔥use_liger_kernel: Whether to enable the [Liger](https://github.com/linkedin/Liger-Kernel) kernel to accelerate training and reduce GPU memory consumption. Defaults to False. Example shell script can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/liger).
- Note: Liger kernel does not support `device_map`. Use DDP or DeepSpeed for multi-GPU training. Currently, liger_kernel only supports `task_type='causal_lm'`.
- use_tiled_mlp: Whether to enable Tiled MLP for memory-efficient long-sequence training. When enabled, MLP layers are replaced with a tiled implementation that processes the sequence in shards to reduce peak memory usage; see the sketch after this list. Defaults to False.
- tiled_mlp_num_shards: Number of shards the sequence is split into for the tiled MLP computation. Defaults to None, which is treated as 4. Larger values reduce memory but may increase computation time.
- average_tokens_across_devices: Whether to average token counts across devices. If `True`, `num_tokens_in_batch` is synchronized via `all_reduce` for accurate loss computation. Default is `False`.
- max_grad_norm: Gradient clipping. Default is 1.
- Note: The logged `grad_norm` reflects the value **before** clipping.
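To make the trade-off concrete, here is a minimal sketch of what a tiled SwiGLU MLP can look like. This is an illustrative assumption, not the PR's implementation (that lives in `swift/plugin/tiled_mlp.py`), and the class name below is hypothetical. Because each token's MLP output is independent, the sequence can be split into shards and each shard run under activation checkpointing, so only one shard's `intermediate_size` activations are alive at a time:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TiledSwiGLUMLPSketch(nn.Module):
    """Hypothetical tiled SwiGLU MLP: same math as a standard
    gate/up/down SwiGLU block, computed one sequence shard at a time."""

    def __init__(self, hidden_size: int, intermediate_size: int, num_shards: int = 4):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()
        self.num_shards = num_shards

    def _mlp(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden_size). The MLP is applied per token,
        # so chunking along the sequence dim is mathematically exact.
        shards = x.chunk(self.num_shards, dim=1)
        # checkpoint() drops each shard's intermediate activations after the
        # forward pass and recomputes them in backward: compute for memory.
        outs = [checkpoint(self._mlp, s, use_reentrant=False) for s in shards]
        return torch.cat(outs, dim=1)
```

Increasing `num_shards` shrinks the live intermediate tensor proportionally at the cost of extra recomputation, which is the trade-off `tiled_mlp_num_shards` exposes.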
25 changes: 25 additions & 0 deletions examples/train/tiled_mlp/fsdp2.json
@@ -0,0 +1,25 @@
{
    "compute_environment": "LOCAL_MACHINE",
    "debug": false,
    "distributed_type": "FSDP",
    "downcast_bf16": "no",
    "fsdp_config": {
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_cpu_ram_efficient_loading": true,
        "fsdp_reshard_after_forward": true,
        "fsdp_state_dict_type": "FULL_STATE_DICT",
        "fsdp_activation_checkpointing": true,
        "fsdp_version": 2
    },
    "machine_rank": 0,
    "main_training_function": "main",
    "mixed_precision": "bf16",
    "num_machines": 1,
    "num_processes": 2,
    "rdzv_backend": "static",
    "same_network": true,
    "tpu_env": [],
    "tpu_use_cluster": false,
    "tpu_use_sudo": false,
    "use_cpu": false
}
24 changes: 24 additions & 0 deletions examples/train/tiled_mlp/train_deepspeed.sh
@@ -0,0 +1,24 @@
CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=2 \
swift sft \
--model Qwen/Qwen3-4B \
--dataset swift/self-cognition#200 \
--train_type full \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--gradient_accumulation_steps 1 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 2048 \
--output_dir output \
--system 'You are a helpful assistant.' \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--use_tiled_mlp true \
--tiled_mlp_num_shards 4 \
--deepspeed zero3
30 changes: 30 additions & 0 deletions examples/train/tiled_mlp/train_fsdp2.sh
@@ -0,0 +1,30 @@
#!/bin/bash
# FSDP2 training with tiled MLP
# Requires accelerate config with fsdp_version: 2

# First, create the accelerate config (fsdp2.json) or use the one in examples/train/multi-gpu/fsdp2_lora/

# FSDP2 with tiled MLP
accelerate launch --config_file fsdp2.json \
-m swift sft \
--model Qwen/Qwen3-4B \
--dataset swift/self-cognition#200 \
--train_type full \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--learning_rate 1e-5 \
--gradient_checkpointing false \
--weight_decay 0.1 \
--gradient_accumulation_steps 1 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 2048 \
--output_dir output \
--system 'You are a helpful assistant.' \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--use_tiled_mlp true \
--tiled_mlp_num_shards 4
4 changes: 4 additions & 0 deletions swift/llm/train/sft.py
@@ -51,6 +51,10 @@ def _prepare_generation_config(self):
    @RayHelper.function(group='default')
    def _prepare_model_tokenizer(self, **kwargs):
        args = self.args
        # Apply tiled MLP before model instantiation
        if getattr(args, 'use_tiled_mlp', False):
            from swift.plugin.tiled_mlp import apply_tiled_mlp
            apply_tiled_mlp(args.model_type, num_shards=getattr(args, 'tiled_mlp_num_shards', None))
        self.model, self.processor = args.get_model_processor(**kwargs)
        if args.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
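The comment notes the patch is applied before model instantiation; one way that can work is by swapping the MLP class on the modeling module, so every decoder layer is already tiled when the model is built. A rough sketch of that mechanism, assuming a Qwen3-style model (the real per-`model_type` dispatch is in `swift/plugin/tiled_mlp.py`, and the function name below is hypothetical):

```python
import torch
from torch.utils.checkpoint import checkpoint
import transformers.models.qwen3.modeling_qwen3 as qwen3_modeling


def apply_tiled_mlp_sketch(num_shards: int = 4) -> None:
    """Illustrative only: replace the MLP class before from_pretrained runs,
    so every decoder layer is constructed with the tiled variant."""
    base_cls = qwen3_modeling.Qwen3MLP

    class TiledMLP(base_cls):
        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Same output as base_cls.forward, computed shard by shard under
            # activation checkpointing to bound peak memory.
            shards = hidden_states.chunk(num_shards, dim=1)
            fwd = super(TiledMLP, self).forward
            outs = [checkpoint(fwd, s, use_reentrant=False) for s in shards]
            return torch.cat(outs, dim=1)

    qwen3_modeling.Qwen3MLP = TiledMLP
```

Since `_prepare_model_tokenizer` passes `args.model_type`, the real implementation presumably selects the matching modeling module per architecture rather than hard-coding one as this sketch does.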
3 changes: 3 additions & 0 deletions swift/plugin/__init__.py
@@ -17,6 +17,7 @@
    from .rm_plugin import rm_plugins
    from .env import envs, Env
    from .context_manager import context_managers, ContextManager
    from .tiled_mlp import (TiledSwiGLUMLP, apply_tiled_mlp, is_fsdp2_enabled, is_fsdp1_enabled, get_tiled_mlp_mode)

else:
    _import_structure = {
@@ -34,6 +35,8 @@
        'rm_plugin': ['rm_plugins'],
        'env': ['envs', 'Env'],
        'context_manager': ['context_managers', 'ContextManager'],
        'tiled_mlp':
        ['TiledSwiGLUMLP', 'apply_tiled_mlp', 'is_fsdp2_enabled', 'is_fsdp1_enabled', 'get_tiled_mlp_mode'],
    }

    import sys