
Commit ae1c7ac

Sequence parallel training context manager (axolotl-ai-cloud#2553)
* ctx manager for SP
* updates
* update
* further simplifying
* accommodate both training context managers
* simplifying
* simplifying
* nit
* reorg
* tweak codecov yaml
* add gather post hook, simplify, fixes
* pytest
* pytest fix
1 parent 1447beb commit ae1c7ac

File tree

12 files changed: +610, -209 lines


codecov.yml

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,7 @@
 codecov:
   require_ci_to_pass: yes
+  notify:
+    wait_for_ci: true

 coverage:
   precision: 2

src/axolotl/core/trainer_builder.py

Lines changed: 0 additions & 3 deletions

@@ -932,9 +932,6 @@ def build_collator(
             collator = DataCollatorForSeq2Seq

         kwargs["return_tensors"] = "pt"
-        if issubclass(collator, DataCollatorForSeq2Seq):
-            kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
-            kwargs["ring_attn_func"] = training_args.ring_attn_func

         return collator(
             *collator_args,

src/axolotl/core/trainers/base.py

Lines changed: 3 additions & 1 deletion

@@ -371,13 +371,15 @@ def compute_loss(
                 num_items_in_batch=num_items_in_batch,
             )

-        return super().compute_loss(
+        loss = super().compute_loss(
             model,
             inputs,
             return_outputs=return_outputs,
             num_items_in_batch=num_items_in_batch,
         )

+        return loss
+
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}

src/axolotl/core/trainers/mixins/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -6,4 +6,4 @@
 from .optimizer import OptimizerMixin
 from .rng_state_loader import RngLoaderMixin
 from .scheduler import SchedulerMixin
-from .sequence_parallel import SequenceParallelMixin
+from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin

src/axolotl/core/trainers/mixins/sequence_parallel.py

Lines changed: 226 additions & 2 deletions

@@ -1,16 +1,86 @@
-"""Module for Axolotl trainer sequence parallelism mixin"""
+"""
+Module for Axolotl trainer sequence parallelism mixin and training context manager
+"""

+import functools
 import logging

+import torch
 import torch.distributed as dist
 from datasets import Dataset
+from torch import nn
 from torch.utils.data import DistributedSampler, Sampler
+from torch.utils.hooks import RemovableHandle

-from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
+from axolotl.monkeypatch.attention.ring_attn import (
+    RingAttnFunc,
+    get_ring_attn_group,
+    update_ring_attn_params,
+)

 LOG = logging.getLogger(__name__)


+def apply_sequence_parallelism(
+    batch: dict[str, torch.Tensor],
+    local_rank: int,
+    local_world_size: int,
+    ring_attn_func: RingAttnFunc,
+) -> dict[str, torch.Tensor]:
+    """
+    Apply sequence parallelism slicing to a batch.
+
+    Args:
+        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
+        local_rank: Local rank in the sequence parallel group
+        local_world_size: World size of the sequence parallel group
+        ring_attn_func: The ring attention function to use
+
+    Returns:
+        Sliced batch dictionary.
+    """
+    # Update ring attention params if needed
+    if batch.get("position_ids") is not None:
+        update_ring_attn_params(position_ids=batch["position_ids"])
+
+    # Slice batch for sequence parallel processing
+    total_seq_len = batch["input_ids"].size(1)
+    for key in batch:
+        if (
+            key in batch
+            and isinstance(batch[key], torch.Tensor)
+            and batch[key].dim() > 1
+            and batch[key].size(1) == total_seq_len
+        ):
+
+            if ring_attn_func in [
+                RingAttnFunc.VARLEN_LLAMA3,
+                RingAttnFunc.BATCH_RING,
+            ]:
+                # Split in sequential fashion and grab this rank's chunk
+                batch[key] = (
+                    batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
+                )
+            elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
+                chunks = batch[key].chunk(2 * local_world_size, dim=1)
+
+                # Take rank's chunk and opposing chunk for zigzag pattern
+                selected_chunks = [
+                    chunks[local_rank],
+                    chunks[2 * local_world_size - local_rank - 1],
+                ]
+                batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
+            elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
+                # Split into striped data and stack
+                tensor = torch.stack(
+                    batch[key].split(local_world_size, dim=1),
+                    dim=1,
+                ).transpose(1, 2)
+                batch[key] = tensor[:, local_rank].contiguous()
+
+    return batch
+
+
 class SequenceParallelMixin:
     """
     Mixin class for sequence parallelism support in trainers.
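
A minimal, single-process sketch of the three slicing patterns implemented by the new apply_sequence_parallelism function above. The two ranks here are simulated in a plain loop and the ring-attention parameter update is omitted, so this is illustrative only, not the axolotl code path:

import torch

# Toy "sequence": 8 token positions so the per-rank shards are easy to read.
seq = torch.arange(8).unsqueeze(0)  # shape (1, 8): [[0, 1, 2, 3, 4, 5, 6, 7]]
world_size = 2

for rank in range(world_size):
    # VARLEN_LLAMA3 / BATCH_RING: contiguous chunk per rank
    ring = seq.chunk(world_size, dim=1)[rank]

    # BATCH_ZIGZAG: rank takes chunk `rank` plus its mirror chunk
    chunks = seq.chunk(2 * world_size, dim=1)
    zigzag = torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=1)

    # BATCH_STRIPE: rank takes every world_size-th position, offset by its rank
    stripe = torch.stack(seq.split(world_size, dim=1), dim=1).transpose(1, 2)[:, rank]

    print(f"rank {rank}: ring={ring.tolist()} zigzag={zigzag.tolist()} stripe={stripe.tolist()}")

# Prints:
# rank 0: ring=[[0, 1, 2, 3]] zigzag=[[0, 1, 6, 7]] stripe=[[0, 2, 4, 6]]
# rank 1: ring=[[4, 5, 6, 7]] zigzag=[[2, 3, 4, 5]] stripe=[[1, 3, 5, 7]]
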
@@ -87,3 +157,157 @@ def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
         return self._create_sequence_parallel_sampler(
             eval_dataset, shuffle=False, is_eval=True
         )
+
+
+class SequenceParallelContextManager:
+    """
+    Context manager for sequence parallelism operations.
+
+    This class provides a context that will automatically apply sequence parallelism
+    during model forward passes using a pre-forward hook, and gather outputs from
+    across the sequence parallelism group using a post-forward hook.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        sequence_parallel_degree: int,
+        ring_attn_func: RingAttnFunc,
+    ):
+        self.model = model
+        self.sequence_parallel_degree = sequence_parallel_degree
+        self.ring_attn_func = ring_attn_func
+        self.process_group = get_ring_attn_group()
+
+        # Initialize sequence parallel group details
+        self.local_rank = dist.get_rank(self.process_group)
+        self.local_world_size = dist.get_world_size(self.process_group)
+
+        # Will store hook handles for removal
+        self.hook_handles: list[RemovableHandle] = []
+
+        # Create a partially applied version of the apply_sequence_parallelism function
+        # with pre-configured params
+        self.apply_sequence_parallelism = functools.partial(
+            apply_sequence_parallelism,
+            local_rank=self.local_rank,
+            local_world_size=self.local_world_size,
+            ring_attn_func=self.ring_attn_func,
+        )
+
+    def __enter__(self):
+        # Forward pre-hook to apply sequence parallelism
+        def sequence_parallel_pre_hook(_, args, kwargs):
+            # Apply sequence parallelism to kwargs
+            kwargs = self.apply_sequence_parallelism(batch=kwargs)
+            return args, kwargs
+
+        # Forward post-hook to gather outputs
+        def sequence_parallel_post_hook(_, __, output):
+            # Gather the sharded outputs
+            return self.gather_outputs(output)
+
+        # Register both hooks
+        self.hook_handles.append(
+            self.model.register_forward_pre_hook(
+                sequence_parallel_pre_hook, with_kwargs=True
+            )
+        )
+        self.hook_handles.append(
+            self.model.register_forward_hook(sequence_parallel_post_hook)
+        )
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Remove all hooks
+        for handle in self.hook_handles:
+            handle.remove()
+        self.hook_handles = []
+
+    def gather_outputs(self, output):
+        """Gather sharded outputs from all ranks and reconstruct the full tensor."""
+        # Handle different output formats (dict, tensor, etc.)
+        if isinstance(output, dict):
+            gathered_output = {}
+            for key, value in output.items():
+                if isinstance(value, torch.Tensor) and value.dim() > 1:
+                    # Gather logits or other sequence-sharded tensors
+                    gathered_value = self.gather_tensor(value)
+                    gathered_output[key] = gathered_value
+                else:
+                    gathered_value = value.clone()
+                    dist.all_reduce(
+                        gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
+                    )
+                    gathered_output[key] = gathered_value
+            return gathered_output
+        if isinstance(output, torch.Tensor):
+            return self.gather_tensor(output)
+
+        return output
+
+    def gather_tensor(self, tensor):
+        """Gather a sharded tensor from all ranks."""
+        # Prepare tensors for all_gather
+        world_size = self.local_world_size
+
+        # Create list to store tensors from all ranks
+        gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
+
+        # All-gather operation
+        dist.all_gather(gathered_tensors, tensor, group=self.process_group)
+
+        # Concatenate along sequence dimension (typically dim=1)
+        if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
+            # Simple concatenation for standard sharding
+            return torch.cat(gathered_tensors, dim=1)
+
+        if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
+            # Each rank has a pattern of (rank, world_size*2-rank-1)
+            reconstituted_tensors = [None] * (world_size * 2)
+
+            # First, split each gathered tensor into its two chunks
+            for rank, gathered_tensor in enumerate(gathered_tensors):
+                # Each tensor contains two chunks in the sequence dimension
+                chunk_size = gathered_tensor.size(1) // 2
+                chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
+
+                # Place chunks in their original positions
+                reconstituted_tensors[rank] = chunk1
+                reconstituted_tensors[world_size * 2 - rank - 1] = chunk2

+            # Concatenate the reconstituted tensors in the correct order
+            return torch.cat(reconstituted_tensors, dim=1)
+
+        # Otherwise, RingAttnFunc.BATCH_STRIPE
+        # In striping, each rank has every world_size-th slice
+        batch_size = tensor.size(0)
+        hidden_dim = tensor.size(-1)
+
+        # First, determine the full sequence length
+        total_seq_len = 0
+        for t in gathered_tensors:
+            total_seq_len += t.size(1)
+
+        # Create a tensor to hold the unstriped result
+        result = torch.zeros(
+            batch_size,
+            total_seq_len,
+            hidden_dim,
+            dtype=tensor.dtype,
+            device=tensor.device,
+        )
+
+        # For each rank's tensor, distribute its slices to the correct positions
+        for rank, gathered_tensor in enumerate(gathered_tensors):
+            # The rank's tensor contains every world_size-th slice
+            # starting from its rank position
+            seq_len = gathered_tensor.size(1)
+            for i in range(seq_len):
+                # Calculate the position in the full tensor
+                pos = i * world_size + rank
+                if pos < total_seq_len:
+                    result[:, pos] = gathered_tensor[:, i]
+
+        return result
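
The hook mechanism that powers the context manager can be exercised on its own. A hedged, single-process sketch with a toy nn.Linear standing in for the model, and simple halve/duplicate steps standing in for the real per-rank slice and dist.all_gather logic (no distributed setup required):

import torch
from torch import nn

model = nn.Linear(4, 4)

# Pre-hook: runs before forward; with_kwargs=True lets it rewrite keyword
# arguments, which is where the context manager slices the batch per rank.
def pre_hook(_module, args, kwargs):
    kwargs["input"] = kwargs["input"][:, : kwargs["input"].size(1) // 2, :]
    return args, kwargs

# Post-hook: runs after forward; this is where sharded outputs get gathered.
def post_hook(_module, _inputs, output):
    return torch.cat([output, output], dim=1)  # stand-in for all_gather + concat

handles = [
    model.register_forward_pre_hook(pre_hook, with_kwargs=True),
    model.register_forward_hook(post_hook),
]

x = torch.randn(1, 8, 4)
y = model(input=x)      # forward sees half the sequence; output is re-expanded
print(y.shape)          # torch.Size([1, 8, 4])

for handle in handles:  # mirrors __exit__: drop the hooks when the context ends
    handle.remove()
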

src/axolotl/train.py

Lines changed: 23 additions & 7 deletions

@@ -6,6 +6,7 @@
 import signal
 import sys
 import weakref
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, Dict

@@ -25,6 +26,9 @@
     fix_untrained_tokens,
 )
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.core.trainers.mixins.sequence_parallel import (
+    SequenceParallelContextManager,
+)
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed

@@ -185,16 +189,28 @@ def execute_training(
         trainer: The configured trainer object.
         resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
     """
-    LOG.info("Starting trainer...")
-    if cfg.flash_optimum:
-        with torch.backends.cuda.sdp_kernel(
-            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
+    # Define the context managers to use
+    flash_context = (
+        torch.backends.cuda.sdp_kernel(
             enable_flash=True,
             enable_math=True,
             enable_mem_efficient=True,
-        ):
-            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
-    else:
+        )
+        if cfg.flash_optimum
+        else nullcontext()
+    )
+    sequence_parallel_context = (
+        SequenceParallelContextManager(
+            model=trainer.model,
+            sequence_parallel_degree=cfg.sequence_parallel_degree,
+            ring_attn_func=cfg.ring_attn_func,
+        )
+        if cfg.sequence_parallel_degree > 1
+        else nullcontext()
+    )
+
+    LOG.info("Starting trainer...")
+    with flash_context, sequence_parallel_context:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
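
The execute_training change above replaces nested if/else branches with composable context managers. A small sketch of the same pattern; demo_context, use_flash, and sequence_parallel_degree below are stand-ins rather than axolotl objects:

from contextlib import contextmanager, nullcontext

@contextmanager
def demo_context(name):
    # Stand-in for torch.backends.cuda.sdp_kernel(...) or
    # SequenceParallelContextManager(...); just logs enter and exit.
    print(f"entering {name}")
    yield
    print(f"exiting {name}")

use_flash = True               # stands in for cfg.flash_optimum
sequence_parallel_degree = 1   # stands in for cfg.sequence_parallel_degree

# Each feature contributes either a real context manager or a no-op
# nullcontext, so one `with` statement covers every combination of settings.
flash_context = demo_context("flash sdp kernel") if use_flash else nullcontext()
sequence_parallel_context = (
    demo_context("sequence parallel") if sequence_parallel_degree > 1 else nullcontext()
)

with flash_context, sequence_parallel_context:
    print("trainer.train() would run here")

# Prints:
# entering flash sdp kernel
# trainer.train() would run here
# exiting flash sdp kernel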
