Draft
47 commits
4316425
wip
hjh0119 Aug 29, 2025
5d46eae
init wip
hjh0119 Sep 1, 2025
5828229
args wip
hjh0119 Sep 1, 2025
a82cec4
Merge remote-tracking branch 'origin/main' into mega-grpo
hjh0119 Sep 2, 2025
0689b76
reuse _prepare_rollout_engine
hjh0119 Sep 3, 2025
46593cf
merge main
hjh0119 Sep 11, 2025
3da8756
mega wip
hjh0119 Sep 12, 2025
2ca7ac1
Merge remote-tracking branch 'origin' into mega-grpo
hjh0119 Sep 17, 2025
d9ec029
wip
hjh0119 Sep 17, 2025
7c56f9f
override train_step wip
hjh0119 Sep 17, 2025
686fc74
remove override train_step to grpo
hjh0119 Sep 18, 2025
095bcbd
Merge remote-tracking branch 'origin' into mega-grpo
hjh0119 Sep 18, 2025
4d9457b
sync weight wip
hjh0119 Sep 18, 2025
f52d5e1
rollout wip
hjh0119 Sep 19, 2025
155d4fb
Merge remote-tracking branch 'origin' into mega-grpo
hjh0119 Sep 22, 2025
3c69c39
modify mini_batch_size to generation batch size
hjh0119 Sep 22, 2025
eebdd47
wip
hjh0119 Sep 24, 2025
de6ecfe
loss wip
hjh0119 Sep 28, 2025
4569e54
fix repeat n
hjh0119 Sep 28, 2025
f118935
Merge remote-tracking branch 'origin' into mega-grpo
hjh0119 Sep 29, 2025
9cb84e3
fix padding to multiple of tp_size
hjh0119 Sep 29, 2025
8627aa3
compute loss
hjh0119 Sep 29, 2025
2292cf8
fix logps
hjh0119 Sep 30, 2025
bbe5f39
logging & patch VL
hjh0119 Sep 30, 2025
6a2940c
fix rollout_group & rollout judgement
hjh0119 Oct 1, 2025
486c3d4
fix step
hjh0119 Oct 6, 2025
7e8e6b0
merge main
hjh0119 Oct 6, 2025
c68d976
move old base trainer to newer
hjh0119 Oct 7, 2025
6b1653c
fix
hjh0119 Oct 8, 2025
d4a9dcc
offload utils
hjh0119 Oct 8, 2025
9dc92a0
offload context
hjh0119 Oct 9, 2025
7bc3d61
Resolve merge conflict in megatron_args.py by removing duplicate fiel…
hjh0119 Oct 9, 2025
91f97ca
fix resolve
hjh0119 Oct 9, 2025
59f436c
fix logps
hjh0119 Oct 9, 2025
8dea6d7
fix old logps
hjh0119 Oct 9, 2025
abac696
reduce redundancy
hjh0119 Oct 9, 2025
3a3ff37
replace token
hjh0119 Oct 10, 2025
2cd89dc
fix offload model
hjh0119 Oct 10, 2025
50d5e6f
offload optimizer & ref
hjh0119 Oct 11, 2025
e1a06c6
support cp
hjh0119 Oct 11, 2025
ff9b667
fix pp+cp
hjh0119 Oct 11, 2025
ba4bfbf
lora wip
hjh0119 Oct 11, 2025
e5a6252
Merge remote-tracking branch 'origin' into mega-grpo
hjh0119 Oct 13, 2025
e22c790
arguments document
hjh0119 Oct 13, 2025
b3de262
wip lora&cp
hjh0119 Oct 14, 2025
d5bd92c
merge origin
hjh0119 Oct 14, 2025
fe3270f
remove unused patch
hjh0119 Oct 14, 2025
186 changes: 182 additions & 4 deletions swift/megatron/argument/megatron_args.py
@@ -10,14 +10,15 @@
from transformers.utils.versions import require_version

from swift.llm.argument.base_args import to_abspath
from swift.utils import get_dist_setting, get_logger, json_parse_to_dict
from swift.utils import get_current_device, get_dist_setting, get_logger, is_master, json_parse_to_dict

logger = get_logger()


@dataclass
class RLHFMegatronArgumentsMixin:
rlhf_type: Literal['dpo', 'kto'] = None
rlhf_type: Literal['dpo', 'kto', 'grpo'] = None
perform_initialization: bool = True
ref_load: Optional[str] = None
ref_adapter_load: Optional[str] = None

@@ -33,6 +34,100 @@ class RLHFMegatronArgumentsMixin:
undesirable_weight: float = 1.
calculate_KL: Optional[bool] = None

# =========================== GRPO ===========================
generation_batch_size: Optional[int] = None
steps_per_generation: Optional[int] = None
num_generations: int = 8
max_completion_length: int = 512
# GSPO https://www.arxiv.org/abs/2507.18071
importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'

# ─────────────────────────── Sampling ───────────────────────────
epsilon: float = 0.2
epsilon_high: Optional[float] = None
delta: Optional[float] = None
top_k: int = 50
top_p: float = 0.9
repetition_penalty: float = 1.
# ─────────────────────────── VLLM ───────────────────────────
use_vllm: bool = False
vllm_mode: Literal['server', 'colocate'] = 'colocate'
# ────────────── Internal VLLM (colocate) ──────────────
vllm_enable_prefix_caching: bool = True
vllm_gpu_memory_utilization: float = 0.9
vllm_tensor_parallel_size: int = 1
vllm_max_model_len: Optional[int] = None
vllm_enforce_eager: bool = False
vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
vllm_disable_cascade_attn: bool = False
sleep_level: Literal[0, 1, 2] = 0

# ────────────── External VLLM (server, not supported yet) ──────────────
vllm_server_base_url: Optional[List[str]] = None
vllm_server_host: Optional[List[str]] = None
vllm_server_port: List[int] = field(default_factory=lambda: [8000])
vllm_server_timeout: float = 240.0
vllm_client: Optional[object] = field(init=False, default=None)

# ─────────────────────────── Reward ───────────────────────────
reward_funcs: List[str] = field(default_factory=list)
reward_weights: List[float] = None
# see details in swift/plugin/orm.py
# cosine reward, https://arxiv.org/abs/2502.03373
cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length.
cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length.
cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length.
cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length.
cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length
# repetition penalty, https://arxiv.org/abs/2502.03373
repetition_n_grams: int = 3
repetition_max_penalty: float = -1.0
# soft_overlong, https://arxiv.org/abs/2503.14476
soft_max_length: Optional[int] = None
soft_cache_length: Optional[int] = None

# ─────────────────────────── Not Supported Yet ───────────────────────────
# reward model
reward_model: Optional[List[str]] = None
reward_model_plugin: Optional[List[str]] = None
# sync ref model
sync_ref_model: bool = False
ref_model_sync_steps: int = 512
ref_model_mixup_alpha: float = 0.6

async_generate: bool = False

move_model_batches: Optional[int] = None
offload_optimizer: bool = False
offload_model: bool = False
gc_collect_after_offload: bool = False # deprecated

# multi turn
multi_turn_func: Optional[str] = None # deprecated
multi_turn_scheduler: Optional[str] = None
max_turns: Optional[int] = None
completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round'
vllm_server_pass_dataset: bool = False

# DAPO, https://arxiv.org/abs/2503.14476
dynamic_sample: bool = False
max_resample_times: int = 3
overlong_filter: bool = False

# Dr. GRPO, https://arxiv.org/abs/2503.20783
scale_rewards: bool = True

# entropy
log_entropy: bool = False
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0

wandb_log_unique_prompts: Optional[bool] = None
num_iterations: int = 1

# dataset
dataset_shuffle: Optional[bool] = True

def _init_kto(self):
if self.calculate_KL is None:
# Not all losses require a KL calculation
@@ -43,11 +138,93 @@ def _init_kto(self):
def __post_init__(self):
if self.rlhf_type is None:
return
default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid'}
default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid', 'grpo': 'grpo'}
if self.loss_type is None:
self.loss_type = default_loss_type[self.rlhf_type]
if self.rlhf_type == 'kto':
self._init_kto()
if self.rlhf_type == 'grpo':
self._init_grpo()

def _init_grpo(self):

def _init_external_vllm():
if self.rlhf_type != 'grpo' or (self.vllm_server_host is None and self.vllm_server_base_url is None):
return
from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
if is_master():
logger.info('Start connecting to vLLM server')
self.vllm_client = VLLMClient(
base_urls=self.vllm_server_base_url,
hosts=self.vllm_server_host,
server_ports=self.vllm_server_port,
connection_timeout=self.vllm_server_timeout)
self.vllm_client.init_communicator(device=get_current_device())
logger.info('Connected to vLLM server')

def _check_not_supported():
pass

def _check_batch_params():
if self.generation_batch_size is None and self.steps_per_generation is None:
self.steps_per_generation = 1
self.generation_batch_size = self.global_batch_size * self.steps_per_generation
elif self.generation_batch_size is not None and self.steps_per_generation is None:
# Just ensure the value is divisible by the global batch size
if self.generation_batch_size % self.global_batch_size != 0:
raise ValueError(f'generation_batch_size ({self.generation_batch_size}) '
f'must be divisible by the global batch size ({self.global_batch_size}).')
self.steps_per_generation = self.generation_batch_size // self.global_batch_size
elif self.generation_batch_size is None and self.steps_per_generation is not None:
self.generation_batch_size = self.global_batch_size * self.steps_per_generation
else:
raise ValueError(
"'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time")
world_size = torch.distributed.get_world_size()
assert self.generation_batch_size % world_size == 0, \
f'generation_batch_size ({self.generation_batch_size}) ' \
f'must be divisible by the world size ({world_size})'
self.per_device_generation_batch_size = self.generation_batch_size // world_size

_init_external_vllm()
_check_not_supported()
_check_batch_params()
# validate that the configured loss_type is supported by GRPO
assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \
f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}'
if self.async_generate or not self.use_vllm:
self.sleep_level = 0
self.remove_unused_columns = False
logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
if self.truncation_strategy is None:
self.truncation_strategy = 'left'
assert self.truncation_strategy in ['left', 'delete'
], ("GRPO requires `truncation_strategy 'left' or 'delete'`, "
f"Current value: `truncation_strategy='{self.truncation_strategy}'`."
) # noqa
if self.beta is None:
self.beta = 0.04 # https://arxiv.org/abs/2402.03300
if self.async_generate:
logger.info('Using async mode. This is an approximate version which '
'will use the old weights to generate responses to accelerate. '
'This will ignore the `CLIP` of advantages; if you find the training '
'is unstable, you may consider using --async_generate false.')
if 'soft_overlong' in self.reward_funcs:
assert self.soft_cache_length is not None, \
'The soft_cache_length must be set when using soft overlong rewards.'
if self.soft_max_length is None:
self.soft_max_length = self.max_completion_length
logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}')
if self.use_vllm:
# set vllm mode
if self.vllm_server_host is not None or self.vllm_server_base_url is not None:
if self.vllm_mode != 'server':
self.vllm_mode = 'server'
logger.warning('set vllm_mode to `server` since vllm server host/base_url is provided')
else:
if self.vllm_mode != 'colocate':
self.vllm_mode = 'colocate'
logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided')


@dataclass
@@ -178,6 +355,7 @@ class MegatronArguments(ExtraMegatronArguments):
dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic'
manual_gc: bool = False
manual_gc_interval: int = 0
use_mbridge: bool = False

# learning rate
lr: Optional[float] = None
@@ -206,7 +384,7 @@ class MegatronArguments(ExtraMegatronArguments):
no_load_rng: bool = False
finetune: bool = False
ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist'
no_initialization: bool = True
no_initialization: bool = False
auto_detect_ckpt_format: bool = True
exit_on_missing_checkpoint: bool = True

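A quick sanity check on the batch bookkeeping in `_check_batch_params` above: the sketch below reproduces the same arithmetic standalone. The concrete numbers are illustrative assumptions, not values taken from this PR.

```python
# Minimal sketch of the GRPO batch-size bookkeeping in _check_batch_params.
# All numbers below are assumptions for illustration.
global_batch_size = 16       # samples consumed per optimizer step
steps_per_generation = 2     # optimizer steps served by one rollout pass
world_size = 8               # total number of ranks

# One rollout pass must supply enough samples for steps_per_generation steps.
generation_batch_size = global_batch_size * steps_per_generation  # 32

# Each rank generates an equal slice of the rollout batch.
assert generation_batch_size % world_size == 0
per_device_generation_batch_size = generation_batch_size // world_size  # 4

# If generation_batch_size were supplied instead, steps_per_generation is derived
# from it, which is why it must be a multiple of global_batch_size.
assert generation_batch_size % global_batch_size == 0
derived_steps_per_generation = generation_batch_size // global_batch_size  # 2

print(generation_batch_size, per_device_generation_batch_size, derived_steps_per_generation)
```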
2 changes: 1 addition & 1 deletion swift/megatron/argument/rlhf_args.py
@@ -7,7 +7,7 @@

@dataclass
class MegatronRLHFArguments(MegatronTrainArguments):
rlhf_type: Literal['dpo', 'kto'] = 'dpo'
rlhf_type: Literal['dpo', 'kto', 'grpo'] = 'dpo'
loss_scale: str = 'last_round'

calculate_per_token_loss: bool = False
2 changes: 1 addition & 1 deletion swift/megatron/argument/train_args.py
@@ -9,7 +9,7 @@
from swift.llm.argument.base_args import to_abspath
from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master
from ..model import get_megatron_model_meta
from .megatron_args import MegatronArguments
from .megatron_args import MegatronArguments, RLHFMegatronArgumentsMixin

logger = get_logger()

17 changes: 12 additions & 5 deletions swift/megatron/train/rlhf.py
@@ -2,9 +2,10 @@
from typing import List, Optional, Union

from swift.llm.train.kto import prepare_kto_dataset
from swift.trainers.rlhf_trainer.utils import identity_data_collator
from swift.utils import get_logger
from ..argument import MegatronRLHFArguments
from ..trainers import MegatronDPOTrainer, MegatronKTOTrainer
from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer, MegatronKTOTrainer
from .sft import MegatronSft

logger = get_logger()
@@ -18,6 +19,8 @@ def prepare_trainer(self):
args = self.args
if args.rlhf_type == 'dpo':
trainer_cls = MegatronDPOTrainer
elif args.rlhf_type == 'grpo':
trainer_cls = MegatronGRPOTrainer
elif args.rlhf_type == 'kto':
trainer_cls = MegatronKTOTrainer
else:
@@ -26,10 +29,14 @@

def _prepare_template(self) -> None:
super()._prepare_template()
if self.args.rlhf_type == 'kto':
self.template.set_mode('kto')
else:
self.template.set_mode('rlhf')
model_mapping = {'grpo': 'train', 'kto': 'kto'}
self.template.set_mode(model_mapping.get(self.args.rlhf_type, 'rlhf'))

def _get_data_collator(self):
if self.args.rlhf_type == 'grpo':
super()._get_data_collator()
return identity_data_collator
return super()._get_data_collator()

def _get_dataset(self):
args = self.args
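Because GRPO builds its rollout batches inside the trainer, `_get_data_collator` hands back `identity_data_collator` rather than the padded collator the other RLHF types use. The real helper lives in `swift/trainers/rlhf_trainer/utils.py`; the sketch below is only an assumed shape, written to show what "identity" means here.

```python
from typing import Any, Dict, List


def identity_data_collator(features: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Assumed stand-in for swift.trainers.rlhf_trainer.utils.identity_data_collator:
    # no padding or tensor stacking here; the per-sample feature dicts are passed
    # through unchanged so the GRPO trainer can batch them after rollout.
    return features
```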
1 change: 1 addition & 0 deletions swift/megatron/trainers/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .dpo_trainer import MegatronDPOTrainer
from .grpo_trainer import MegatronGRPOTrainer
from .kto_trainer import MegatronKTOTrainer
from .trainer import MegatronTrainer
11 changes: 8 additions & 3 deletions swift/megatron/trainers/base.py
@@ -3,7 +3,7 @@
import os
import time
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import datetime
from typing import Dict, Literal

@@ -27,8 +27,10 @@
from megatron.training.training import num_floating_point_operations
from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model
from packaging import version
from torch.distributed.nn import all_reduce
from transformers.utils import ContextManagers

from swift.llm import dynamic_gradient_checkpointing
from swift.llm import Template, dynamic_gradient_checkpointing
from swift.plugin import MeanMetric
from swift.trainers import SwiftMixin
from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger
@@ -41,7 +43,7 @@

class BaseMegatronTrainer(ABC):

def __init__(self, args, template):
def __init__(self, args, template: Template):
self.args = args
self.template = template
self.stimer = StragglerDetector()
@@ -70,9 +72,11 @@ def initialize_megatron(*_args, **kwargs):
args = get_args()
data_parallel_size = mpu.get_data_parallel_world_size()
step_batch_size = args.micro_batch_size * data_parallel_size
num_generations = args.num_generations if hasattr(args, 'num_generations') else 1
if args.train_iters is None and args.max_epochs is not None:
if hasattr(train_dataset, '__len__'):
dataset_sample = len(train_dataset) // step_batch_size * step_batch_size
dataset_sample = dataset_sample * num_generations
args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size
else:
raise ValueError(
@@ -82,6 +86,7 @@ def initialize_megatron(*_args, **kwargs):
args.eval_iters = 0
elif hasattr(val_dataset, '__len__'):
dataset_sample = len(val_dataset) // step_batch_size * step_batch_size
dataset_sample = dataset_sample * num_generations
args.eval_iters = max(dataset_sample // args.global_batch_size, 1)
else:
raise ValueError(
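The `num_generations` scaling above reflects that every prompt is expanded into `num_generations` completions before being counted against `global_batch_size`. A rough worked example of the resulting `train_iters` (all numbers are assumptions, not taken from the PR):

```python
# Illustrative arithmetic for the train_iters adjustment above; values are assumptions.
len_train_dataset = 1000
micro_batch_size = 2
data_parallel_size = 4
num_generations = 8
global_batch_size = 64
max_epochs = 1

step_batch_size = micro_batch_size * data_parallel_size                  # 8
dataset_sample = len_train_dataset // step_batch_size * step_batch_size  # 1000
dataset_sample *= num_generations                                        # 8000 rollout samples
train_iters = dataset_sample * max_epochs // global_batch_size           # 125
print(train_iters)
```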