-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy patharguments.py
More file actions
629 lines (567 loc) · 37.1 KB
/
arguments.py
File metadata and controls
629 lines (567 loc) · 37.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import platform
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Union
from transformers.training_args import TrainingArguments as HfTrainingArguments
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments
from swift.plugin import loss_mapping
from swift.utils import get_dist_setting, get_logger, is_cce_available, is_liger_available, is_mp, json_parse_to_dict
from .optimizers.galore import GaLoreConfig
# Module-level logger used by the argument mixins defined in this file.
logger = get_logger()
@dataclass
class TrainArgumentsMixin:
    """A dataclass mixin for configuring model training parameters.

    This class is a mixin: its ``__post_init__`` ends with ``super().__post_init__()``, so it must be combined
    with a class further down the MRO that also defines ``__post_init__``.

    Args:
        per_device_train_batch_size (int): The batch size per GPU/TPU core for training. Defaults to 1.
        per_device_eval_batch_size (int): The batch size per GPU/TPU core for evaluation. Defaults to 1.
        gradient_accumulation_steps (Optional[int]): The number of update steps to accumulate gradients for before
            performing an optimizer step. Defaults to None, in which case it is set to
            `max(1, ceil(16 / per_device_train_batch_size / world_size))`.
        tuner_backend (Optional[str]): The backend to use for parameter-efficient fine-tuning (e.g., 'peft'). Defaults
            to None.
        gradient_checkpointing (bool): If True, use gradient checkpointing to save memory at the cost of a slower
            backward pass. Defaults to True.
        vit_gradient_checkpointing (Optional[bool]): A specific gradient checkpointing setting for the Vision
            Transformer part of the model. Defaults to None, in which case it follows `gradient_checkpointing`.
        gradient_checkpointing_kwargs (Optional[Union[dict, str]]): Keyword arguments for
            `torch.utils.checkpoint.checkpoint`, passed as a JSON string or a dictionary. Defaults to None.
        logging_first_step (bool): Whether to log the first global step. Defaults to True.
        logging_steps (int): Log every `logging_steps` global steps. Defaults to 5.
        router_aux_loss_coef (float): The coefficient for the router auxiliary loss in Mixture-of-Experts models.
            Defaults to 0.0.
        enable_dft_loss (bool): Whether to enable the Dynamic Fine-Tuning (DFT) loss.
            See https://arxiv.org/abs/2508.05629. Defaults to False.
        enable_channel_loss (bool): Whether to enable channel loss. Defaults to False.
        weight_decay (float): The weight decay to apply (if not zero) to all layers except bias and LayerNorm weights.
            Defaults to 0.1.
        adam_beta2 (float): The beta2 hyperparameter for the AdamW optimizer. Defaults to 0.95.
        lr_scheduler_type (str): The learning rate scheduler type to use. Defaults to 'cosine'.
        lr_scheduler_kwargs (Optional[Union[dict, str]]): Additional keyword arguments for the learning rate scheduler,
            passed as a JSON string or a dictionary. Defaults to None.
        report_to (List[str]): The list of integrations to report results to (e.g., 'tensorboard', 'wandb'). Defaults
            to ['tensorboard'].
        dataloader_num_workers (Optional[int]): The number of subprocesses to use for data loading. Defaults to None,
            in which case it is set to 0 on Windows and 1 otherwise.
        dataloader_persistent_workers (bool): If True, the data loader workers will not be shut down after a dataset
            has been consumed once. Defaults to False.
        dataloader_prefetch_factor (Optional[int]): The number of batches loaded in advance by each worker. Defaults
            to None, in which case it is set to 10 when `dataloader_num_workers > 0`.
        use_liger_kernel (bool): Whether to use the Liger kernel for optimization. Defaults to False.
        use_cce (bool): Whether to use ml-cross-entropy fused kernels for optimization. Defaults to False.
        check_model (bool): If True, checks local model files for corruption or modification and provides a warning.
            Should be set to False in an offline environment. Defaults to True.
        acc_strategy (Literal['token', 'seq']): The strategy for calculating accuracy during training and validation.
            Can be 'token' for token-level accuracy or 'seq' for sequence-level accuracy. Defaults to 'token'.
        train_dataloader_shuffle (bool): Whether to shuffle the training data. Defaults to True.
        max_epochs (Optional[int]): The total number of training epochs to perform. Overrides `num_train_epochs`.
            Defaults to None.
        aligner_lr (Optional[float]): A specific learning rate for the aligner part of the model. Defaults to None.
        vit_lr (Optional[float]): A specific learning rate for the Vision Transformer part of the model. Defaults to
            None.
        use_logits_to_keep (Optional[bool]): If enabled, reduces VRAM usage and speeds up training by calculating and
            storing only the necessary logits based on the labels during the forward pass. If None, the behavior is
            automatically determined. Defaults to None.
        ds3_gather_for_generation (bool): In DeepSpeed ZeRO-3, whether to gather model parameters for generation.
            Defaults to True.
        resume_only_model (bool): When resuming from a checkpoint, whether to load only the model weights and not the
            optimizer/scheduler states. Defaults to False.
        optimizer (Optional[str]): The name of a custom optimizer from a plugin. If None, a default optimizer is used,
            except that 'multimodal' is selected when `vit_lr` or `aligner_lr` is set. Defaults to None.
        loss_type (Optional[str]): The name of a custom loss function from a plugin. If None, the model's default loss
            function is used. Defaults to None.
        metric (Optional[str]): The name of a custom metric from a plugin. If None, it defaults to 'nlg' when
            `predict_with_generate=True`. Defaults to None.
        eval_use_evalscope (bool): Whether to use EvalScope for evaluation during training. Must be set to `True` to
            enable it. Refer to examples for usage details. Defaults to False.
        eval_dataset (List[str]): A list of evaluation dataset names. Multiple datasets can be specified, separated
            by spaces.
        eval_dataset_args (Optional[Union[str, dict]]): Arguments for the evaluation dataset(s), provided as a JSON
            string or a dictionary.
        eval_limit (Optional[int]): The maximum number of samples to use from the evaluation dataset. Defaults to None.
        eval_generation_config (Optional[Union[str, dict]]): Model inference configuration for evaluation, provided as
            a JSON string or a dictionary, e.g., `{'max_tokens': 512}`. Defaults to None.
        extra_eval_args (Optional[Union[str, dict]]): Extra arguments for evaluation, provided as a JSON string or a
            dictionary.
        use_flash_ckpt (bool): Whether to enable DLRover Flash Checkpoint. When enabled, weights are first saved to
            shared memory and then asynchronously persisted to disk. Currently does not support the safetensors format.
            It is recommended to use this with `PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"` to prevent CUDA OOM
            errors during training. Defaults to False.
    """
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: Optional[int] = None
    tuner_backend: Optional[str] = None

    gradient_checkpointing: bool = True
    vit_gradient_checkpointing: Optional[bool] = None
    gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
    logging_first_step: bool = True
    logging_steps: int = 5
    router_aux_loss_coef: float = 0.
    enable_dft_loss: bool = False  # https://arxiv.org/abs/2508.05629
    enable_channel_loss: bool = False

    weight_decay: float = 0.1
    adam_beta2: float = 0.95
    lr_scheduler_type: str = 'cosine'
    lr_scheduler_kwargs: Optional[Union[dict, str]] = None
    report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
    dataloader_num_workers: Optional[int] = None
    dataloader_persistent_workers: bool = False
    dataloader_prefetch_factor: Optional[int] = None
    use_liger_kernel: bool = False
    use_cce: bool = False

    # extra
    check_model: bool = True
    acc_strategy: Literal['token', 'seq'] = 'token'
    train_dataloader_shuffle: bool = True
    max_epochs: Optional[int] = None
    aligner_lr: Optional[float] = None
    vit_lr: Optional[float] = None
    use_logits_to_keep: Optional[bool] = None
    ds3_gather_for_generation: bool = True
    resume_only_model: bool = False

    optimizer: Optional[str] = None
    loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(loss_mapping.keys())}'})
    metric: Optional[str] = None

    # train-eval loop args
    eval_use_evalscope: bool = False
    eval_dataset: List[str] = field(default_factory=list)
    eval_dataset_args: Optional[Union[str, dict]] = None
    eval_limit: Optional[int] = None
    eval_generation_config: Optional[Union[str, dict]] = None
    extra_eval_args: Optional[Union[str, dict]] = None

    # dlrover flash_checkpoint
    use_flash_ckpt: bool = False

    @staticmethod
    def _patch_liger_kernel():
        """Replace liger's `LigerForCausalLMLoss` with a wrapper that makes `hidden_states` contiguous first."""
        # fix logits_to_keep
        # NOTE(review): presumably the hidden states sliced for logits_to_keep are non-contiguous - confirm.
        from liger_kernel.transformers.model import loss_utils
        origin_LigerForCausalLMLoss = loss_utils.LigerForCausalLMLoss

        def LigerForCausalLMLoss(hidden_states, *args, **kwargs):
            hidden_states = hidden_states.contiguous()
            return origin_LigerForCausalLMLoss(hidden_states, *args, **kwargs)

        loss_utils.LigerForCausalLMLoss = LigerForCausalLMLoss
        logger.info('Patch liger_kernel successfully.')

    def _init_liger(self):
        """Check that liger-kernel is installed and apply the loss patch when `use_liger_kernel` is set.

        Raises:
            AssertionError: If `use_liger_kernel` is True but liger-kernel is not available.
        """
        if self.use_liger_kernel:
            assert is_liger_available(), 'use_liger_kernel requires liger_kernels, try `pip install liger-kernel`'
            try:
                self._patch_liger_kernel()
            except Exception as e:
                # The patch is best-effort: training continues without it, but the failure is
                # reported instead of being silently discarded.
                logger.warning(f'Failed to patch liger_kernel, continuing without the patch: {e}')

    def _init_cce(self):
        """Check that cut-cross-entropy is installed when `use_cce` is set.

        Raises:
            AssertionError: If `use_cce` is True but cut-cross-entropy is not available.
        """
        if self.use_cce:
            assert is_cce_available(), 'use_cce requires cut-cross-entropy, try `pip install cut-cross-entropy`'

    def __post_init__(self):
        """Validate the kernel options and fill in every default that is left as None.

        Raises:
            ValueError: If `use_liger_kernel` is set while `is_mp()` is true.
            ImportError: If `eval_use_evalscope` is set but evalscope cannot be imported.
        """
        if is_mp() and self.use_liger_kernel:
            raise ValueError('liger_kernel does not support device_map. '
                             'Please use DDP/DeepSpeed for multi-GPU training.')

        if self.use_cce and self.use_liger_kernel:
            logger.warning('Enabling both use_cce and use_liger_kernel may lead to duplicated kernel patches.')

        # A dedicated learning rate for the ViT/aligner selects the 'multimodal' optimizer,
        # unless an optimizer was chosen explicitly.
        if self.optimizer is None and (self.vit_lr is not None or self.aligner_lr is not None):
            self.optimizer = 'multimodal'
        if self.gradient_accumulation_steps is None:
            world_size = get_dist_setting()[2]
            # Target at least 16 samples per optimizer step across all processes.
            self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size))
            logger.info(f'Setting args.gradient_accumulation_steps: {self.gradient_accumulation_steps}')
        if self.lr_scheduler_kwargs:
            self.lr_scheduler_kwargs = json_parse_to_dict(self.lr_scheduler_kwargs)
        if self.vit_gradient_checkpointing is None:
            self.vit_gradient_checkpointing = self.gradient_checkpointing
        if self.gradient_checkpointing_kwargs:
            self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs)
        self._init_liger()
        self._init_cce()
        if self.dataloader_num_workers is None:
            if platform.system() == 'Windows':
                self.dataloader_num_workers = 0
            else:
                self.dataloader_num_workers = 1
            logger.info(f'Setting args.dataloader_num_workers: {self.dataloader_num_workers}')
        if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0:
            self.dataloader_prefetch_factor = 10

        if self.eval_use_evalscope:
            try:
                import evalscope  # noqa: F401  (imported only to verify that it is installed)
            except ImportError as e:
                # Chain the original error so the underlying import failure stays visible.
                raise ImportError('evalscope is not installed, please install it by `pip install evalscope`') from e
            self.eval_dataset_args = json_parse_to_dict(self.eval_dataset_args)
            self.eval_generation_config = json_parse_to_dict(self.eval_generation_config)
            self.extra_eval_args = json_parse_to_dict(self.extra_eval_args)

        super().__post_init__()
@dataclass
class RLHFArgumentsMixin:
    """Argument mixin holding the GKD and CHORD training options.

    Args:
        sft_alpha (float): Weight of the SFT loss term in GKD; the final loss is
            `gkd_loss + sft_alpha * sft_loss`. Defaults to 0.
        chord_sft_dataset (List[str]): SFT dataset(s) supplying expert data for the CHORD algorithm. Defaults to
            `[]`.
        chord_sft_per_device_train_batch_size (Optional[int]): Per-device SFT mini-batch size for the CHORD
            algorithm. Defaults to None.
        chord_enable_phi_function (bool): Whether to enable CHORD's token-wise weighting function phi (φ).
            Defaults to False.
        chord_mu_warmup_steps (Optional[int]): Number of training steps over which mu (μ) warms up to its peak
            value. Defaults to None.
        chord_mu_decay_steps (Optional[int]): Number of training steps over which mu (μ) decays from its peak to
            its valley value. Defaults to None.
        chord_mu_peak (Optional[float]): Peak value of mu (μ) in its schedule. Defaults to None.
        chord_mu_valley (Optional[float]): Final (valley) value of mu (μ) after decay. Defaults to None.
    """
    # ---- GKD ----
    sft_alpha: float = 0

    # ---- CHORD ----
    chord_sft_dataset: List[str] = field(default_factory=list)
    chord_sft_per_device_train_batch_size: Optional[int] = None
    chord_enable_phi_function: bool = False

    # mu (μ) schedule
    chord_mu_warmup_steps: Optional[int] = None
    chord_mu_decay_steps: Optional[int] = None
    chord_mu_peak: Optional[float] = None
    chord_mu_valley: Optional[float] = None
@dataclass
class SwiftArgumentsMixin(RLHFArgumentsMixin, TrainArgumentsMixin):
    """A dataclass for configuring additional training parameters.

    Every field below defaults to None in this mixin; per the in-code note, the values are copied from
    TrainArguments. The "resolved" values listed here are therefore the ones documented for that source
    rather than literal defaults of this class.

    Args:
        train_type (Optional[str]): The training type. Can be 'lora', 'full', 'longlora', 'adalora', 'llamapro',
            'adapter', 'vera', 'boft', 'fourierft', or 'reft'. Defaults to None here (documented value: 'lora').
        local_repo_path (Optional[str]): Path to a local repository. Some models (e.g., deepseek-vl2) depend on a
            GitHub repo for loading. Using a local repo avoids network issues during 'git clone'. Defaults to None.
        galore_config (Optional[GaLoreConfig]): GaLore configuration. Defaults to None.
        padding_side (Optional[str]): The padding side for training when batch_size >= 2. Can be 'left' or 'right'.
            Defaults to None here (documented value: 'right').
            Note: For inference with batch_size >= 2, only left padding is performed. PPO and GKD default to 'left'.
        padding_free (Optional[bool]): Whether to flatten data within a batch to avoid padding, reducing VRAM usage
            and speeding up training. Sequences within the same batch remain isolated. Defaults to None here
            (documented value: False). Currently supports CPT, SFT, DPO, GRPO, KTO, and GKD.
            Note: It is recommended to use this with '--attn_impl flash_attn' and 'transformers>=4.44'. Compared to
            packing, padding_free has no preprocessing overhead, but packing is faster and has more stable VRAM usage.
        task_type (Optional[str]): The type of task. Can be 'causal_lm', 'seq_cls', 'embedding', 'reranker', or
            'generative_reranker'. Defaults to None here (documented value: 'causal_lm'). If set to 'seq_cls', you
            usually need to also set '--num_labels' and '--problem_type'.
        problem_type (Optional[str]): Required for classification models (i.e., when task_type is 'seq_cls'). Can be
            'regression', 'single_label_classification', or 'multi_label_classification'. Defaults to None, which is
            resolved to 'regression' if the model is a reward_model or num_labels is 1, and
            'single_label_classification' otherwise.
    """
    # Value copied from TrainArguments
    train_type: Optional[str] = None
    local_repo_path: Optional[str] = None
    galore_config: Optional[GaLoreConfig] = None
    padding_side: Optional[str] = None
    padding_free: Optional[bool] = None
    task_type: Optional[str] = None
    problem_type: Optional[str] = None

    def __post_init__(self):
        """Normalize `output_dir` to an absolute, user-expanded path, then continue the `__post_init__` chain."""
        # `output_dir` is not declared on this mixin; it is only normalized when the combined class
        # provides it. A None value is left untouched, because os.path.expanduser(None) raises TypeError.
        output_dir = getattr(self, 'output_dir', None)
        if output_dir is not None:
            self.output_dir = os.path.abspath(os.path.expanduser(output_dir))
        super().__post_init__()
@dataclass
class VllmArguments:
    """VllmArguments is a dataclass that holds the configuration for vllm.

    `get_vllm_engine_kwargs` additionally reads `adapters`, `num_labels` and `task_type` (and `adapter_mapping`
    when present). These are not declared here and must be provided by the class this one is combined with.

    Args:
        vllm_gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
        vllm_tensor_parallel_size (int): Tensor parallelism size. Default is 1.
        vllm_pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
        vllm_enable_expert_parallel (bool): Flag to enable expert parallelism for MoE models. Default is False.
        vllm_max_num_seqs (int): Maximum number of sequences. Default is 256.
        vllm_max_model_len (Optional[int]): Maximum model length. Default is None.
        vllm_disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is True.
        vllm_enforce_eager (bool): Flag to enforce eager execution. Default is False.
        vllm_limit_mm_per_prompt (Optional[Union[dict, str]]): Limit multimedia per prompt, passed as a JSON string
            or a dictionary, e.g. `'{"image": 5, "video": 2}'`. Default is None.
        vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
        vllm_enable_prefix_caching (Optional[bool]): Flag to enable automatic prefix caching. Default is None.
        vllm_use_async_engine (bool): Whether to use async engine for vLLM. Default is False.
        vllm_quantization (Optional[str]): The quantization method for vLLM. Default is None.
        vllm_reasoning_parser (Optional[str]): The reasoning parser for vLLM. Default is None.
        vllm_disable_cascade_attn (bool): Flag to disable cascade attention. Default is False.
        vllm_mm_processor_cache_gb (Optional[float]): MM processor cache size in GB. Default is None.
        vllm_speculative_config (Optional[Union[dict, str]]): Speculative decoding configuration, passed in as a JSON
            string. Defaults to None.
        vllm_engine_kwargs (Optional[Union[dict, str]]): Additional parameters for vllm, formatted as a JSON string.
            Defaults to None.
        vllm_data_parallel_size (int): Data parallelism size for vLLM rollout. Default is 1.
    """
    # vllm
    vllm_gpu_memory_utilization: float = 0.9
    vllm_tensor_parallel_size: int = 1
    vllm_pipeline_parallel_size: int = 1
    vllm_enable_expert_parallel: bool = False
    vllm_max_num_seqs: int = 256
    vllm_max_model_len: Optional[int] = None
    vllm_disable_custom_all_reduce: bool = True
    vllm_enforce_eager: bool = False
    vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None  # '{"image": 5, "video": 2}'
    vllm_max_lora_rank: int = 16
    vllm_enable_prefix_caching: Optional[bool] = None
    vllm_use_async_engine: bool = False
    vllm_quantization: Optional[str] = None
    vllm_reasoning_parser: Optional[str] = None
    vllm_disable_cascade_attn: bool = False
    vllm_mm_processor_cache_gb: Optional[float] = None
    vllm_speculative_config: Optional[Union[dict, str]] = None
    vllm_engine_kwargs: Optional[Union[dict, str]] = None
    # rollout
    vllm_data_parallel_size: int = 1

    def __post_init__(self):
        """Pass the JSON-style options through `json_parse_to_dict`."""
        if self.vllm_limit_mm_per_prompt is not None:
            self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt)
        if self.vllm_speculative_config is not None:
            self.vllm_speculative_config = json_parse_to_dict(self.vllm_speculative_config)
        # Unlike the two options above, this one is passed through unconditionally (including None).
        self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)

    def get_vllm_engine_kwargs(self):
        """Build the keyword arguments used to construct the vLLM engine.

        Returns:
            dict: The `vllm_*` options with the `vllm_` prefix removed (`vllm_engine_kwargs` is kept under
            'engine_kwargs'), plus 'enable_lora' / 'max_loras' derived from the adapters, 'num_labels', and
            'task_type' when the task is 'embedding', 'seq_cls' or a reranker variant.
        """
        adapters = self.adapters
        if hasattr(self, 'adapter_mapping'):
            # Concatenation builds a new list, so `self.adapters` is not mutated.
            adapters = adapters + list(self.adapter_mapping.values())
        kwargs = {
            'gpu_memory_utilization': self.vllm_gpu_memory_utilization,
            'tensor_parallel_size': self.vllm_tensor_parallel_size,
            'pipeline_parallel_size': self.vllm_pipeline_parallel_size,
            'enable_expert_parallel': self.vllm_enable_expert_parallel,
            'max_num_seqs': self.vllm_max_num_seqs,
            'max_model_len': self.vllm_max_model_len,
            'disable_custom_all_reduce': self.vllm_disable_custom_all_reduce,
            'enforce_eager': self.vllm_enforce_eager,
            'limit_mm_per_prompt': self.vllm_limit_mm_per_prompt,
            'max_lora_rank': self.vllm_max_lora_rank,
            'enable_lora': len(adapters) > 0,
            'max_loras': max(len(adapters), 1),
            'enable_prefix_caching': self.vllm_enable_prefix_caching,
            'use_async_engine': self.vllm_use_async_engine,
            'quantization': self.vllm_quantization,
            'reasoning_parser': self.vllm_reasoning_parser,
            'disable_cascade_attn': self.vllm_disable_cascade_attn,
            'mm_processor_cache_gb': self.vllm_mm_processor_cache_gb,
            'speculative_config': self.vllm_speculative_config,
            'num_labels': self.num_labels,
            'engine_kwargs': self.vllm_engine_kwargs,
        }
        # `task_type` may still be None (its declared default elsewhere in this file); guard the
        # substring test, since `'reranker' in None` raises TypeError.
        task_type = self.task_type
        if task_type is not None and (task_type in ('embedding', 'seq_cls') or 'reranker' in task_type):
            kwargs['task_type'] = task_type
        return kwargs
@dataclass
class RolloutTrainerArgumentsMixin(VllmArguments):
    """Arguments controlling rollout (completion generation) for GRPO-style training.

    Extends `VllmArguments` with sampling options and with the settings needed when vLLM serves as the
    inference backend for generation. Two inherited defaults are overridden here: `vllm_max_num_seqs`
    (None instead of 256) and `vllm_enable_prefix_caching` (True instead of None).

    Args:
        top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to 50.
        top_p (float): If set to a float < 1, only the smallest set of most probable tokens with probabilities that
            add up to top_p or higher are kept for generation. Defaults to 0.9.
        repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
        stop_words (List[str]): A list of strings that will stop the generation when they are generated. Defaults to an
            empty list.
        use_vllm (bool): Whether to use vLLM as the inference backend for GRPO generation. Defaults to False.
        vllm_mode (Literal['server', 'colocate']): The vLLM integration mode. 'server' mode uses a vLLM server launched
            by swift rollout for sampling. 'colocate' mode deploys vLLM within the same process. For full-parameter
            training in 'server' mode, the environment variable `SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE` can be set to
            control the bucket size for weight synchronization (in MB, defaults to 512). Defaults to 'colocate'.
        vllm_max_num_seqs (Optional[int]): Overrides the inherited field of the same name. Defaults to None.
        vllm_enable_prefix_caching (bool): A pass-through parameter for vLLM. Enables prefix caching. Used in
            'colocate' mode. Defaults to True.
        vllm_enable_lora (bool): Enables the vLLM Engine to load LoRA adapters. This is used to accelerate weight
            synchronization during LoRA training. See documentation for details. Used in 'colocate' mode. Defaults
            to False.
        lora_rank (int): The rank for the LoRA adapter loaded by the vLLM engine. When using `vllm_enable_lora`, this
            should be greater than or equal to (ideally equal to) the rank of the LoRA model being trained. Used in
            'colocate' mode. Defaults to 8.
        vllm_server_base_url (Optional[List[str]]): The base URL of the vLLM server (e.g., "http://localhost:8000").
            If set, `vllm_server_host` and `vllm_server_port` are ignored. Used in 'server' mode. Defaults to None.
        vllm_server_host (Optional[List[str]]): The host address(es) of the vLLM
            server. Used in 'server' mode. Defaults to None.
        vllm_server_port (List[int]): The port(s) of the vLLM server. Used in 'server' mode. Defaults to `[8000]`.
        vllm_server_timeout (float): Timeout in seconds for connecting to the vLLM server. Used in 'server' mode.
            Defaults to 240.0.
        vllm_client: Internal client instance for vLLM server communication. Not intended to be set by the user.
            It carries no type annotation, so it is a plain class attribute rather than a dataclass field.
        vllm_server_group_port (List[int]): Defaults to `[51216]`. Presumably the port(s) used for the
            trainer/server communication group in 'server' mode - TODO confirm against the rollout server code.
        enable_flattened_weight_sync (bool): Defaults to True. Presumably toggles a flattened-tensor path for weight
            synchronization - TODO confirm.
        async_generate (bool): Whether to perform asynchronous rollout to improve training speed. Note: When enabled,
            sampling uses the model from the previous update step and is not compatible with multi-turn scenarios.
            Defaults to False.
        sleep_level (int): Specifies the level of GPU memory release for vLLM during training steps. Options: 0
            (no release), 1, 2. A higher level releases more memory but may incur overhead. Defaults to 0.
        move_model_batches (Optional[int]): The number of batches after which the model is moved back to the GPU if it
            was offloaded. Used for memory management during training. Defaults to None.
        offload_optimizer (bool): Whether to offload optimizer states to CPU/RAM when performing inference with vLLM to
            save GPU memory. Defaults to False.
        offload_model (bool): Whether to offload the model weights to CPU/RAM when performing inference with vLLM.
            Defaults to False.
        wandb_log_unique_prompts (Optional[bool]): Whether to log unique prompts to Weights & Biases for analysis
            during training. Defaults to None.
    """
    # ---- sampling / generation ----
    top_k: int = 50
    top_p: float = 0.9
    repetition_penalty: float = 1.
    stop_words: List[str] = field(default_factory=list)

    # ---- vLLM backend selection ----
    use_vllm: bool = False
    vllm_mode: Literal['server', 'colocate'] = 'colocate'

    # ---- in-process vLLM ('colocate' mode) ----
    vllm_max_num_seqs: Optional[int] = None
    vllm_enable_prefix_caching: bool = True  # overwrite
    vllm_enable_lora: bool = False
    lora_rank: int = 8  # for vllm lora adapter

    # ---- external vLLM ('server' mode) ----
    vllm_server_base_url: Optional[List[str]] = None
    vllm_server_host: Optional[List[str]] = None
    vllm_server_port: List[int] = field(default_factory=lambda: [8000])
    vllm_server_timeout: float = 240.0
    vllm_client = None  # Not required to set, used for client instantiation
    vllm_server_group_port: List[int] = field(default_factory=lambda: [51216])
    enable_flattened_weight_sync: bool = True

    # ---- scheduling and memory management ----
    async_generate: bool = False
    sleep_level: int = 0
    move_model_batches: Optional[int] = None
    offload_optimizer: bool = False
    offload_model: bool = False

    wandb_log_unique_prompts: Optional[bool] = None
@dataclass
class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
"""A dataclass for configuring parameters for algorithms like DAPO, Dr.GRPO, GSPO, RLOO, and REINFORCE++.
Args:
epsilon (float): The clipping coefficient. Defaults to 0.2.
epsilon_high (Optional[float]): The upper clipping coefficient. If set, it forms a clipping range of
`[epsilon, epsilon_high]` with epsilon. Defaults to None.
delta (Optional[float]): The upper clipping value for two-sided GRPO from the INTELLECT-2 tech
report. If set, it is recommended to be greater than `1 + epsilon`. Defaults to None.
cosine_min_len_value_wrong (float): The reward for wrong answers with zero completion length
(r^w_0 in the paper). Defaults to -0.5.
cosine_max_len_value_wrong (float): The reward for wrong answers with maximum completion length
(r^w_L in the paper). Defaults to 0.0.
cosine_min_len_value_correct (float): The reward for correct answers with zero completion length
(r^c_0 in the paper). Defaults to 1.0.
cosine_max_len_value_correct (float): The reward for correct answers with maximum completion length
(r^c_L in the paper). Defaults to 0.5.
cosine_max_len (Optional[int]): The maximum length for generated text (Lmax in the paper). Defaults
to `max_completion_length`.
repetition_n_grams (int): The n-gram size for repetition detection. Defaults to 3.
repetition_max_penalty (float): The maximum penalty value, used to control the strength of the
penalty. Defaults to -1.0.
reward_model (Optional[List[str]]): The reward model(s) to use. Defaults to None.
reward_model_plugin (Optional[List[str]]): The plugin logic for the reward model. Defaults to 'orm'
logic. See custom reward models for details. Defaults to None.
sync_ref_model (bool): Whether to periodically synchronize the `ref_model`. Defaults to False.
ref_model_sync_steps (int): The synchronization frequency. Defaults to 512.
ref_model_mixup_alpha (float): Controls the mixup between the current model and the previous
`ref_model` during updates. Defaults to 0.6.
multi_turn_scheduler (Optional[str]): Parameter for multi-turn GRPO. Pass the corresponding plugin
name. The implementation should be added in `plugin/multi_turn.py`. Defaults to None.
max_turns (Optional[int]): The maximum number of turns for multi-turn GRPO. If None, there is no
limit. Defaults to None.
completion_length_limit_scope (Literal['total', 'per_round']): The scope of the
`max_completion_length` limit in multi-turn dialogue. 'total' limits the total output length across all
turns, while 'per_round' limits the output length for each turn. Defaults to 'per_round'.
vllm_server_pass_dataset (bool): Pass extra dataset information to the vLLM server, used for
multi-turn training. Defaults to False.
dynamic_sample (bool): If True, filters out data with a reward standard deviation of 0 within a group
and samples new data. Defaults to False.
max_resample_times (int): When `dynamic_sample` is enabled, this limits the number of resampling
attempts. Defaults to 3.
overlong_filter (bool): If True, skips samples that are truncated due to being too long, so they are
not included in the loss calculation. Defaults to False.
soft_max_length (Optional[int]): The maximum generation length of the model (L_max in the paper).
Defaults to `max_completion_length`.
soft_cache_length (Optional[int]): Controls the length penalty interval (L_cache in the paper).
The interval is `[soft_max_length - soft_cache_length, soft_max_length]`. Defaults to None.
scale_rewards (Optional[Literal['group', 'batch', 'none']]): Specifies the reward scaling strategy.
Options are 'group' (scale by intra-group standard deviation), 'batch' (scale by the entire batch's
standard deviation), or 'none' (no scaling). The default value is tied to `advantage_estimator`: 'group'
for 'grpo', 'none' for 'rloo', and 'batch' for 'reinforce_plus_plus'. In ms-swift < 3.10, this was a
boolean where `True` corresponded to 'group' and `False` to 'none'.
log_entropy (bool): Log the dynamics of entropy values during training. See documentation for
details. Defaults to False.
top_entropy_quantile (float): Only tokens with entropy in the top specified quantile participate in
the loss calculation. A value of 1.0 means no tokens are filtered. See documentation for details.
Defaults to 1.0.
importance_sampling_level (Literal['token', 'sequence', 'sequence_token']): Controls the importance
sampling ratio calculation. 'token' mode retains the original log probability ratio for each token.
'sequence' mode averages the log probability ratios of all valid tokens in the sequence. The GSPO paper
uses 'sequence' level for stable training. Defaults to 'token'.
tau_pos (float): The temperature parameter for positive dominance in the SAPO algorithm, controlling the
sharpness of the soft gating function. Larger values result in sharper gating (approaching hard
clipping), while smaller values result in smoother gating. The default value is 1.0.
tau_neg (float): The temperature parameter for negative dominance in the SAPO algorithm, controlling the
sharpness of the soft gating function. Typically, `tau_neg` is set > `tau_pos` to impose stronger
constraints on negative dominance. The default value is 1.05.
advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus']): The advantage estimation
function to use. 'grpo' calculates the relative advantage within a group. Options are 'grpo', 'rloo',
'reinforce_plus_plus'. Defaults to 'grpo'.
kl_in_reward (Optional[bool]): Controls how the KL divergence regularization term is handled. If
`False`, it's an independent term in the loss function. If `True`, KL is directly incorporated into the
reward (subtracted from it). The default is tied to `advantage_estimator`: `False` for 'grpo', `True` for
'rloo' and 'reinforce_plus_plus'.
generation_batch_size (Optional[int]): The batch size for sampling completions. It should be a
multiple of `num_processes * per_device_train_batch_size`. Defaults to `per_device_train_batch_size *
gradient_accumulation_steps * num_processes`.
steps_per_generation (Optional[int]): The number of optimization steps per generation round. Only
one of `steps_per_generation` and `generation_batch_size` can be set. Defaults to
`gradient_accumulation_steps`.
dataset_shuffle (Optional[bool]): Whether to shuffle the dataset. Defaults to True.
rollout_importance_sampling_mode (Optional[Literal['token_truncate', 'token_mask', 'sequence_truncate',
'sequence_mask']]): The training-inference mismatch correction mode. Options are `token_truncate`,
`token_mask`, `sequence_truncate`, and `sequence_mask`. Defaults to None, disabling correction.
See the documentation for details.
rollout_importance_sampling_threshold (float): The threshold for importance sampling weights, used to truncate
or mask extreme weights. Defaults to 2.0.
log_rollout_offpolicy_metrics (bool): Whether to log rollout off-policy diagnostic metrics (KL, PPL, chi2, etc.)
when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set,
metrics are always logged regardless of this setting. Defaults to False.
"""
epsilon: float = 0.2
epsilon_high: Optional[float] = None
delta: Optional[float] = None
# reward function args, see details in swift/plugin/orm.py
# cosine reward, https://arxiv.org/abs/2502.03373
cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length.
cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length.
cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length.
cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length.
cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length
# repetition penalty, https://arxiv.org/abs/2502.03373
repetition_n_grams: int = 3
repetition_max_penalty: float = -1.0
reward_model: Optional[List[str]] = None
reward_model_plugin: Optional[List[str]] = None
# sync ref model
sync_ref_model: bool = False
ref_model_sync_steps: int = 512
ref_model_mixup_alpha: float = 0.6
# multi turn
multi_turn_scheduler: Optional[str] = None
max_turns: Optional[int] = None
completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round'
vllm_server_pass_dataset: bool = False
# DAPO, https://arxiv.org/abs/2503.14476
dynamic_sample: bool = False
max_resample_times: int = 3
overlong_filter: bool = False
soft_max_length: Optional[int] = None
soft_cache_length: Optional[int] = None
# Dr. GRPO, https://arxiv.org/abs/2503.20783
scale_rewards: Optional[Literal['group', 'batch', 'none']] = None
# entropy
log_entropy: bool = False
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0
# GSPO https://arxiv.org/abs/2507.18071
importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'
# SAPO https://arxiv.org/abs/2511.20347
# Temperature parameters for soft adaptive gate
tau_pos: float = 1.0
tau_neg: float = 1.05
# RLOO, REINFORCE++
advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus'] = 'grpo'
# If false, add KL into loss, otherwise add into reward
kl_in_reward: Optional[bool] = None # rloo/reinforce_plus_plus: true, grpo: false (default)
generation_batch_size: Optional[int] = None
steps_per_generation: Optional[int] = None
# dataset
dataset_shuffle: Optional[bool] = True
# Rollout Importance Sampling Correction (off-policy correction)
# Set to None to disable, or choose from: 'token_truncate', 'token_mask', 'sequence_truncate', 'sequence_mask'
rollout_importance_sampling_mode: Optional[Literal['token_truncate', 'token_mask', 'sequence_truncate',
'sequence_mask']] = None
rollout_importance_sampling_threshold: float = 2.0 # Threshold for truncation/masking (C in paper)
log_rollout_offpolicy_metrics: bool = False # Log off-policy metrics even when IS correction is disabled
@dataclass
class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
    """Training arguments that layer ``SwiftArgumentsMixin`` over ``HfTrainingArguments``.

    This class declares no fields of its own; every option is inherited from
    its two base classes.
    """
@dataclass
class Seq2SeqTrainingArguments(SwiftArgumentsMixin, HfSeq2SeqTrainingArguments):
    """Seq2seq training arguments that layer ``SwiftArgumentsMixin`` over ``HfSeq2SeqTrainingArguments``.

    This class declares no fields of its own; every option is inherited from
    its two base classes.
    """