
Commit 80304c2

SP GRPO support + batch SP fixes (axolotl-ai-cloud#2643)
* ctx manager for SP
* updates
* update
* further simplifying
* simplifying
* simplifying
* reorg
* batch api HF adapter for ring-flash-attn; cleanup and improvements
* update
* adding all batch ring-flash-attn methods via single adapter
* fix
* fixes for batch API funcs, simplify
* fix
* grpo sp support
* progress
* stronger subclassing of TRL GRPO trainer; custom distributed sampler
* subclassing constructor
* progress
* finalizing SP + GRPO trainer
* minimize diffs to GRPO trainer
* remove (most of) the custom GRPO trainer logic
* debug
* debug
* update
* update
* update
* progress
* cleanup
* cleanup
* minor changes
* update
* update
* update
* small changes
* updates
* cleanup; torch.compile ring_flash_attn functions to prevent numerical instability; lint
* spacing
* cleanup; log in pydantic model config only on main process
* remove comment
* fix sp sampler, update to latest upstream code, doc
* add docs
* update quartodoc autodoc contents
* fix, simplifications
* fixes + simplifications
* review comments
* lint
* removing main process only logs in favor of axolotl-ai-cloud#2608
* fixes, additional smoke test
* updates
* more tests
* update
* fix grad accum bug (sort of)
* lint, tests
* todo
1 parent 67c4ea9 commit 80304c2

27 files changed: +1454 −461 lines changed

_quarto.yml

Lines changed: 16 additions & 1 deletion
@@ -48,8 +48,23 @@ quartodoc:
       contents:
         - core.trainers.base
         - core.trainers.trl
+        - core.trainers.mamba
+        - core.trainers.relora
         - core.trainers.dpo.trainer
         - core.trainers.grpo.trainer
+        - core.trainers.grpo.sampler
+        - core.trainers.utils
+    - title: Mixins
+      desc: Mixin classes for augmenting trainers
+      contents:
+        - core.trainers.mixins.optimizer
+        - core.trainers.mixins.rng_state_loader
+        - core.trainers.mixins.scheduler
+        - core.trainers.mixins.sequence_parallel
+    - title: Context Managers
+      desc: Context managers for altering trainer behaviors
+      contents:
+        - utils.ctx_managers.sequence_parallel
     - title: Prompt Strategies
       desc: Prompt formatting strategies
       contents:
@@ -86,7 +101,7 @@ quartodoc:
         - kernels.swiglu
         - kernels.quantize
         - kernels.utils
-    - title: MonkeyPatches
+    - title: Monkey Patches
       desc: Runtime patches for model optimizations
       contents:
         - monkeypatch.llama_attn_hijack_flash

docs/sequence_parallelism.qmd

Lines changed: 1 addition & 3 deletions
@@ -3,8 +3,6 @@ title: Sequence Parallelism
 description: Train with long sequences split across multiple GPUs.
 ---
 
-# Sequence Parallelism
-
 Sequence parallelism is a technique that splits sequences across multiple GPUs,
 allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
 GPU processes a different portion of the sequence, and the results are aggregated
@@ -27,7 +25,7 @@ To enable sequence parallelism, add the following to your configuration file:
 sequence_parallel_degree: 4 # Split sequences across 4 GPUs
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
-# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
+# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
 # "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
 ring_attn_func:
 ```
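The doc above only shows the sequence-parallelism keys in isolation. As a rough illustration of how they might sit alongside GRPO in a full Axolotl config (the point of this commit), here is a sketch; only `sequence_parallel_degree`, `heads_k_stride`, `ring_attn_func`, and the `grpo` RL mode come from this diff, while every other key and value is an assumed placeholder to check against the Axolotl documentation for your version.

```yaml
# Hypothetical SP + GRPO config sketch -- values outside the
# sequence-parallelism block are illustrative assumptions.
base_model: Qwen/Qwen2.5-7B-Instruct   # assumed example model
rl: grpo                               # GRPO training, which this commit makes SP-aware

sequence_parallel_degree: 4            # split each sequence across 4 GPUs
heads_k_stride: 1                      # optional; more memory, potentially faster
ring_attn_func: batch_ring             # "varlen_llama3" is the default with sample_packing: true

sequence_len: 32768
micro_batch_size: 1
gradient_accumulation_steps: 4
```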

src/axolotl/common/datasets.py

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer
+from axolotl.utils.schemas.enums import RLType
 from axolotl.utils.tokenization import check_dataset_labels
 
 LOG = logging.getLogger(__name__)
@@ -133,7 +134,7 @@ def load_preference_datasets(
     total_num_steps: Optional[int] = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
-    if cfg.rl == "grpo":
+    if cfg.rl is RLType.GRPO:
         total_num_steps = None
 
     if cli_args.debug or cfg.debug:
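The `total_num_steps` expression above is plain ceiling division over the dataset; GRPO sets it to `None` instead, presumably so the GRPO trainer can derive its own schedule. A quick check of the formula with made-up numbers:

```python
import math

# Illustrative numbers only (not from the commit): 10,000 samples,
# 2 epochs, effective batch size 32.
num_samples, num_epochs, batch_size = 10_000, 2, 32

# Same formula as load_preference_datasets() above.
total_num_steps = int(math.ceil(num_samples * num_epochs / batch_size))
print(total_num_steps)  # 625 steps for the non-GRPO RL paths
```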

src/axolotl/core/trainer_builder.py

Lines changed: 46 additions & 37 deletions
@@ -87,7 +87,7 @@
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
 from axolotl.utils.models import ensure_dtype
-from axolotl.utils.schemas.enums import CustomSupportedOptimizers
+from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
 
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -353,7 +353,7 @@ def build(self, total_num_steps):
         training_arguments_kwargs["warmup_steps"] = warmup_steps
         training_arguments_kwargs["logging_steps"] = logging_steps
 
-        if self.cfg.seed:
+        if self.cfg.seed is not None:
             training_arguments_kwargs["seed"] = self.cfg.seed
 
         if self.cfg.gradient_checkpointing:
@@ -547,8 +547,6 @@ def build(self, total_num_steps):
         report_to = []
         if self.cfg.use_wandb:
             report_to.append("wandb")
-            if self.cfg.wandb_name:
-                training_arguments_kwargs["run_name"] = self.cfg.wandb_name
         if self.cfg.use_mlflow:
             report_to.append("mlflow")
         if self.cfg.use_tensorboard:
@@ -821,14 +819,15 @@ def build(self, total_num_steps):
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
+        multiple = 64
         if self.cfg.pad_to_sequence_len:
-            data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
-                self.cfg.sequence_len / 64
+            data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
+                self.cfg.sequence_len / multiple
             )
         else:
             # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
-            data_collator_kwargs["pad_to_multiple_of"] = 64
+            data_collator_kwargs["pad_to_multiple_of"] = multiple
 
         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len
@@ -1034,6 +1033,10 @@ def build_training_arguments(self, total_num_steps):
             training_args_kwargs["dataloader_prefetch_factor"] = (
                 self.cfg.dataloader_prefetch_factor
             )
+
+        if self.cfg.seed is not None:
+            training_args_kwargs["seed"] = self.cfg.seed
+
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (
                 self.cfg.gradient_checkpointing
@@ -1076,23 +1079,27 @@ def build_training_arguments(self, total_num_steps):
         if self.cfg.use_wandb:
             training_args_kwargs["run_name"] = self.cfg.wandb_name
 
+        training_args_kwargs["sequence_parallel_degree"] = (
+            self.cfg.sequence_parallel_degree
+        )
+
         training_args_cls = None
         blocklist_args_kwargs = []
-        if self.cfg.rl == "simpo":
+        if self.cfg.rl is RLType.SIMPO:
             training_args_cls = AxolotlCPOConfig
             training_args_kwargs["loss_type"] = "simpo"
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
 
-        elif self.cfg.rl == "orpo":
+        elif self.cfg.rl is RLType.ORPO:
             training_args_cls = AxolotlORPOConfig
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
-        elif self.cfg.rl == "kto":
+        elif self.cfg.rl is RLType.KTO:
             training_args_cls = AxolotlKTOConfig
 
             training_args_kwargs["desirable_weight"] = (
@@ -1106,14 +1113,14 @@ def build_training_arguments(self, total_num_steps):
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
-        elif self.cfg.rl == "grpo":
+        elif self.cfg.rl is RLType.GRPO:
             training_args_cls = GRPOStrategy.get_training_args_class()
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
             blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
 
         else:
             training_args_cls = AxolotlDPOConfig
-            if self.cfg.rl == "ipo":
+            if self.cfg.rl is RLType.IPO:
                 training_args_kwargs["loss_type"] = "ipo"
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             training_args_kwargs["max_completion_length"] = None
@@ -1156,67 +1163,69 @@ def build_training_arguments(self, total_num_steps):
 
     def build(self, total_num_steps):
         training_args = self.build_training_arguments(total_num_steps)
-        dpo_trainer_kwargs = {}
-        if self.cfg.rl == "ipo":
+        trainer_kwargs = {}
+        if self.cfg.rl is RLType.IPO:
             if self.cfg.dpo_label_smoothing:
-                dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
+                trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         if self.eval_dataset:
-            dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
+            trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
-            dpo_trainer_kwargs["peft_config"] = self.peft_config
+            trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
-            dpo_trainer_kwargs["precompute_ref_log_probs"] = (
+            trainer_kwargs["precompute_ref_log_probs"] = (
                 self.cfg.precompute_ref_log_probs
             )
-        if self.cfg.rl == "grpo":
-            trainer_cls = GRPOStrategy.get_trainer_class()
+        if self.cfg.rl is RLType.GRPO:
+            trainer_cls = GRPOStrategy.get_trainer_class(
+                sequence_parallel=self.cfg.sequence_parallel_degree > 1
+            )
             trainer_cls_args = [self.model]
             trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
-            dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
-        elif self.cfg.rl in ["dpo", "ipo"]:
+            trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
+        elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
            trainer_cls = DPOStrategy.get_trainer_class()
            trainer_cls_args = [self.model, self.model_ref]
-        elif self.cfg.rl == "orpo":
+        elif self.cfg.rl is RLType.ORPO:
            trainer_cls = AxolotlORPOTrainer
            trainer_cls_args = [self.model]
-        elif self.cfg.rl in ["kto"]:
+        elif self.cfg.rl is RLType.KTO:
            trainer_cls = AxolotlKTOTrainer
            trainer_cls_args = [self.model]
-        elif self.cfg.rl in ["simpo"]:
+        elif self.cfg.rl is RLType.SIMPO:
            trainer_cls = AxolotlCPOTrainer
            trainer_cls_args = [self.model]
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
 
         sig = inspect.signature(trainer_cls)
         if "tokenizer" in sig.parameters.keys():
-            dpo_trainer_kwargs["tokenizer"] = self.tokenizer
+            trainer_kwargs["tokenizer"] = self.tokenizer
         else:
-            dpo_trainer_kwargs["processing_class"] = self.tokenizer
+            trainer_kwargs["processing_class"] = self.tokenizer
 
         if self.cfg.datasets is not None and (
             trainer_cls is DPOStrategy.get_trainer_class()
         ):
-            dpo_trainer_kwargs["dataset_tags"] = [
+            trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
-        dpo_trainer = trainer_cls(
+        trainer = trainer_cls(
             *trainer_cls_args,
             args=training_args,
             train_dataset=self.train_dataset,
             callbacks=self.get_callbacks(),
-            **dpo_trainer_kwargs,
+            **trainer_kwargs,
         )
         if self.cfg.fsdp:
-            ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
-            if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
-                ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
+            ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
+            if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
+                ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
 
-        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
-        for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
-            dpo_trainer.add_callback(callback)
+        trainer = self.hook_post_create_trainer(trainer)
+        for callback in self.get_post_trainer_create_callbacks(trainer):
+            trainer.add_callback(callback)
 
-        return dpo_trainer
+        return trainer
 
 
 class HFPPOTrainerBuilder(TrainerBuilderBase):
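Two small behavioral notes on the builder changes: the seed guard moves from truthiness to `is not None`, so an explicit `seed: 0` is no longer silently ignored, and the collator's padding multiple is now named once instead of repeating the literal `64`. A standalone sketch of what `pad_to_multiple_of` ends up enforcing (the sample lengths are assumptions, not from the commit):

```python
import math

def padded_length(sequence_len: int, multiple: int = 64) -> int:
    """Round sequence_len up to the next multiple, as pad_to_multiple_of enforces."""
    return multiple * math.ceil(sequence_len / multiple)

# Hypothetical lengths: 4100 tokens pads up to 4160 (the next multiple of 64),
# while an already-aligned 4096 is left unchanged.
print(padded_length(4100))  # 4160
print(padded_length(4096))  # 4096
```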

src/axolotl/core/trainers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
-from .grpo.trainer import AxolotlGRPOTrainer
+from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
 from .relora import ReLoRATrainer
 from .trl import (
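The new `AxolotlGRPOSequenceParallelTrainer` export pairs with the `GRPOStrategy.get_trainer_class(sequence_parallel=...)` call in the builder above. The strategy itself is not part of this excerpt, so the selection logic below is only a sketch under the assumption that it simply branches on the flag:

```python
# Hypothetical sketch -- the real GRPOStrategy in axolotl may differ.
from axolotl.core.trainers.grpo.trainer import (
    AxolotlGRPOSequenceParallelTrainer,
    AxolotlGRPOTrainer,
)


class GRPOStrategy:
    """Strategy for GRPO training (sketch)."""

    @classmethod
    def get_trainer_class(cls, sequence_parallel: bool = False):
        # Use the SP-aware subclass only when sequence parallelism is enabled.
        if sequence_parallel:
            return AxolotlGRPOSequenceParallelTrainer
        return AxolotlGRPOTrainer
```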

src/axolotl/core/trainers/base.py

Lines changed: 1 addition & 3 deletions
@@ -373,15 +373,13 @@ def compute_loss(
                 num_items_in_batch=num_items_in_batch,
             )
 
-        loss = super().compute_loss(
+        return super().compute_loss(
             model,
             inputs,
             return_outputs=return_outputs,
             num_items_in_batch=num_items_in_batch,
         )
 
-        return loss
-
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}

src/axolotl/core/trainers/dpo/__init__.py

Lines changed: 4 additions & 7 deletions
@@ -1,14 +1,11 @@
-"""
-DPO Specific Strategy for training
-"""
+"""DPO Specific Strategy for training"""
 
 from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
+from axolotl.utils.schemas.enums import RLType
 
 
 class DPOStrategy:
-    """
-    Strategy for DPO training
-    """
+    """Strategy for DPO training"""
 
     @classmethod
     def get_trainer_class(cls):
@@ -23,7 +20,7 @@ def get_training_args_class(cls):
     @classmethod
     def set_training_args_kwargs(cls, cfg):
         training_args_kwargs = {}
-        if cfg.rl == "ipo":
+        if cfg.rl is RLType.IPO:
             training_args_kwargs["loss_type"] = "ipo"
             training_args_kwargs["max_length"] = cfg.sequence_len
             training_args_kwargs["max_completion_length"] = None
