
Commit 59cd472

SP cu_seqlens fix, refactor (axolotl-ai-cloud#2495)
* working on masking fix
* refactor and fix multipack seqlens
* pre-commit fix
* adding smoke test
* using existing packed seqlens util
* log warning re: logged losses / gradient scaling per rank
1 parent 9b89591 commit 59cd472

File tree

8 files changed: +150 additions, −178 deletions


src/axolotl/core/trainers/base.py

Lines changed: 3 additions & 0 deletions
@@ -235,6 +235,9 @@ def _prepare_dataloader(
         self.accelerator.even_batches = False
 
         # Return unprepared dataloader if using sequence parallelism
+        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
+        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
+        # slice each batch along the sequence dimension).
         if self.args.sequence_parallel_degree > 1:
             return dataloader
 

src/axolotl/core/trainers/mixins/sequence_parallel.py

Lines changed: 1 addition & 94 deletions

@@ -1,34 +1,22 @@
 """Module for Axolotl trainer sequence parallelism mixin"""
 
 import logging
-from typing import Any
 
-import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from datasets import Dataset
-from torch import nn
 from torch.utils.data import DistributedSampler, Sampler
 
 from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
 
 LOG = logging.getLogger(__name__)
 
-try:
-    from ring_flash_attn import update_ring_flash_attn_params
-except ImportError:
-    # We pass silently here, but raise an ImportError in our Axolotl config validation
-    # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
-    pass
-
 
 class SequenceParallelMixin:
     """
     Mixin class for sequence parallelism support in trainers.
 
     This mixin provides functionality for handling sequence parallelism,
-    including creating appropriate samplers, managing data partitioning,
-    and updating ring flash attention parameters during training.
+    specifically for creating appropriate data samplers.
     """
 
     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

@@ -99,84 +87,3 @@ def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
         return self._create_sequence_parallel_sampler(
             eval_dataset, shuffle=False, is_eval=True
         )
-
-    def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
-        """
-        Calculate the cu_seqlens for the current forward pass and pass the value to
-        the substituted ring_flash_attn. This is accomplished by using the passed
-        `input_ids`.
-
-        Args:
-            inputs: Current batch of inputs.
-        """
-        # At this point, inputs should already be partitioned by the sequence
-        # parallel data collator
-        batch_size = inputs["input_ids"].shape[0]
-        seq_len = inputs["input_ids"].shape[1]
-        packed_seq_lens = [seq_len] * batch_size
-
-        # Calculate the full sequence length across all GPUs in this SP group
-        total_seq_len = seq_len * self.args.sequence_parallel_degree
-
-        cu_seqlens = torch.cumsum(
-            torch.tensor(
-                packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
-            ),
-            dim=-1,
-            dtype=torch.int32,
-        )
-        cu_seqlens = F.pad(
-            F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
-        )
-
-        update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
-
-    def training_step(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor | Any],
-        num_items_in_batch: int | None = None,
-    ) -> torch.Tensor:
-        """
-        Perform a training step on a batch of inputs. Overrides the
-        `transformers.trainer.Trainer` method to handle sequence parallelism if
-        enabled.
-
-        Args:
-            model: Model to perform training step for.
-            inputs: Dictionary mapping.
-        """
-        # Set up sequence parallelism for this step if enabled
-        if self.args.sequence_parallel_degree > 1:
-            self._update_ring_flash_attn_params(inputs)
-
-        # Proceed with normal training step
-        return super().training_step(model, inputs, num_items_in_batch)  # type: ignore
-
-    def prediction_step(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor | Any],
-        prediction_loss_only: bool,
-        ignore_keys: list[str] | None = None,
-    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
-        """
-        Perform a prediction step on a batch of inputs. Overrides the
-        `transformers.trainer.Trainer` method to handle sequence parallelism if
-        enabled.
-
-        Args:
-            model: Model to perform prediction step for.
-            inputs: Dictionary mapping of inputs.
-            prediction_loss_only: Whether to return only the loss.
-            ignore_keys: Keys to ignore in the inputs.
-
-        Returns:
-            Tuple of (loss, logits, labels).
-        """
-        # Set up sequence parallelism for this prediction step if enabled
-        if self.args.sequence_parallel_degree > 1:
-            self._update_ring_flash_attn_params(inputs)
-
-        # Proceed with normal prediction step
-        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)  # type: ignore

src/axolotl/monkeypatch/attention/ring_attn.py

Lines changed: 26 additions & 0 deletions
@@ -6,10 +6,12 @@
 their sequence parallel version of Flash Attention 2.
 """
 
+import torch
 import torch.distributed as dist
 from accelerate.logging import get_logger
 
 from axolotl.logging_config import configure_logging
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 configure_logging()
 LOG = get_logger(__name__)

@@ -98,3 +100,27 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
     substitute_hf_flash_attn(
         process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
     )
+
+
+def update_ring_attn_params(batch: dict[str, torch.Tensor]):
+    """
+    Calculate the cumulative sequence lengths for the current forward pass and pass the
+    value to the substituted `ring_flash_attn`.
+
+    Args:
+        batch: A dictionary with a batch of data. May or may not contain `position_ids`
+            data; if not, we compute it.
+    """
+    from ring_flash_attn import update_ring_flash_attn_params
+
+    input_ids = batch["input_ids"]
+    position_ids = batch.get("position_ids")
+    if position_ids is None:
+        seq_len = input_ids.shape[1]
+        position_ids = torch.arange(
+            0, seq_len, dtype=torch.long, device=input_ids.device
+        ).unsqueeze(0)
+
+    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
+    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
+    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
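
For intuition about what the new `update_ring_attn_params` hands to `ring_flash_attn`: with sample packing, `position_ids` restart at 0 at each packed sample boundary, and the cumulative sequence lengths mark those boundaries. Below is a minimal standalone sketch of that conversion; it is an illustration only, not the `get_cu_seqlens_from_pos_ids` helper used in the commit.

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Packed position_ids restart at 0 at each sample boundary, e.g. two samples
    # of lengths 4 and 5 -> [0, 1, 2, 3, 0, 1, 2, 3, 4].
    starts = torch.nonzero(position_ids == 0).flatten()
    seq_lens = torch.diff(starts, append=torch.tensor([position_ids.numel()]))
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return cu_seqlens

position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3, 4])
print(cu_seqlens_from_position_ids(position_ids))  # tensor([0, 4, 9], dtype=torch.int32)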

src/axolotl/monkeypatch/utils.py

Lines changed: 3 additions & 1 deletion
@@ -96,7 +96,9 @@ def get_cu_seqlens(attn_mask):
     return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
 
 
-def get_cu_seqlens_from_pos_ids(position_ids):
+def get_cu_seqlens_from_pos_ids(
+    position_ids: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
     """generate a cumulative sequence length mask for flash attention using pos ids"""
     if len(position_ids.shape) == 1:
         position_ids = position_ids.unsqueeze(0)

src/axolotl/utils/collators/batching.py

Lines changed: 12 additions & 55 deletions
@@ -3,7 +3,6 @@
 includes logic for handling sequence parallelism collation.
 """
 
-import logging
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 

@@ -13,46 +12,7 @@
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
-logger = logging.getLogger(__name__)
-
-
-def adjust_position_ids_for_slice(
-    position_ids: torch.Tensor, start_idx: int
-) -> torch.Tensor:
-    """
-    Adjust position IDs for a sliced sequence to maintain proper relative positions.
-    This handles the case where position IDs might not be contiguous due to sample
-    packing.
-    """
-    # Convert to tensor if not already
-    # Find the boundaries between samples (where position_ids reset)
-    adjusted_pos_ids = position_ids.clone()
-
-    # Process each sequence in the batch
-    for i in range(position_ids.shape[0]):
-        seq = position_ids[i]
-
-        # Find sample boundaries
-        boundaries = []
-        for j in range(1, len(seq)):
-            if seq[j] < seq[j - 1]:
-                boundaries.append(j)
-
-        # No need to adjust if there are no boundaries or this is a single sample
-        if not boundaries:
-            adjusted_pos_ids[i] = seq - start_idx
-            continue
-
-        # Adjust each segment separately
-        prev_boundary = 0
-        for boundary in boundaries:
-            adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
-            prev_boundary = boundary
-
-        # Last segment
-        adjusted_pos_ids[i, prev_boundary:] -= start_idx
-
-    return adjusted_pos_ids
+from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
 
 
 @dataclass

@@ -196,23 +156,20 @@ def apply_sequence_parallelism(
         Returns:
             Sliced batch dictionary.
         """
-        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
+        # Get local (start, end) for sequence parallelism slicing
+        total_seq_len = batch["input_ids"].shape[1]
+        slice_size = total_seq_len // self.local_world_size
+        start = self.local_rank * slice_size
+        end = start + slice_size
+
+        # Update params for ring attention calculation
+        update_ring_attn_params(batch=batch)
 
+        # Slice batch for sequence parallel processing
+        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
         for key in keys_to_slice:
             if key in batch:
-                seq_len = batch[key].shape[1]
-                slice_size = seq_len // self.local_world_size
-                start_idx = self.local_rank * slice_size
-                end_idx = (
-                    start_idx + slice_size
-                    if self.local_rank < self.local_world_size - 1
-                    else seq_len
-                )
-                batch[key] = batch[key][:, start_idx:end_idx]
-
-                # Special handling for position_ids
-                if key == "position_ids" and self.local_rank > 0:
-                    batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
+                batch[key] = batch[key][:, start:end]
 
         return batch
 
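
The refactored `apply_sequence_parallelism` above keeps `position_ids` untouched and simply slices every per-token tensor into equal, contiguous chunks along the sequence dimension, one chunk per rank in the SP group. A minimal sketch of that slicing, assuming the packed sequence length is divisible by the SP degree (which packing / pad_to_sequence_len arranges):

import torch

KEYS_TO_SLICE = ("input_ids", "attention_mask", "labels", "position_ids")

def slice_for_rank(batch: dict, local_rank: int, local_world_size: int) -> dict:
    # Every rank in the SP group sees the same batch; each keeps one contiguous
    # chunk of the sequence dimension.
    total_seq_len = batch["input_ids"].shape[1]
    slice_size = total_seq_len // local_world_size
    start, end = local_rank * slice_size, (local_rank + 1) * slice_size
    return {
        key: value[:, start:end] if key in KEYS_TO_SLICE else value
        for key, value in batch.items()
    }

batch = {"input_ids": torch.arange(8).unsqueeze(0)}  # shape (1, 8)
print(slice_for_rank(batch, local_rank=0, local_world_size=2)["input_ids"])  # [[0, 1, 2, 3]]
print(slice_for_rank(batch, local_rank=1, local_world_size=2)["input_ids"])  # [[4, 5, 6, 7]]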

src/axolotl/utils/schemas/config.py

Lines changed: 18 additions & 0 deletions
@@ -1156,6 +1156,12 @@ def check_sequence_parallel_degree(cls, value, info):
                 "flash_attention: true must be set with sequence_parallel_degree > 1"
             )
 
+        if not info.data["micro_batch_size"] == 1:
+            raise ValueError(
+                "micro_batch_size must be set to 1 "
+                "due to a `ring-flash-attn` requirement"
+            )
+
         try:
             import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
         except ImportError as exception:

@@ -1165,6 +1171,18 @@ def check_sequence_parallel_degree(cls, value, info):
                 "or `pip install ring-flash-attn>=0.1.4`."
             ) from exception
 
+        # TODO: monkeypatch / callback to average losses correctly across SP ranks
+        # / fix gradient scaling across SP ranks. Losses, grads should be scaled
+        # according to the proportion of non-padding tokens per rank.
+        LOG.warning(
+            "Sequence parallelism (SP) is enabled with "
+            f"sequence_parallel_degree={value}. Please note that logged losses may "
+            "differ slightly to the non-SP losses due to transformers Trainer "
+            "implementation details. Please see "
+            "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+            "for more details."
+        )
+
         return value
 
     @model_validator(mode="before")
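
The TODO above concerns loss averaging: each SP rank computes a mean loss over only its slice of the sequence, so a plain average across ranks differs from the token-weighted average whenever ranks hold different numbers of loss-bearing tokens. A hypothetical sketch of the per-rank rescaling the TODO describes follows; the function name and placement are assumptions, not part of this commit.

import torch
import torch.distributed as dist

def rescale_loss_for_sp(loss: torch.Tensor, labels: torch.Tensor, sp_group) -> torch.Tensor:
    # Tokens that actually contribute to the loss on this rank (labels == -100 are ignored).
    local_tokens = (labels != -100).sum().float()
    group_tokens = local_tokens.clone()
    dist.all_reduce(group_tokens, group=sp_group)  # total loss-bearing tokens in the SP group
    world_size = dist.get_world_size(group=sp_group)
    # After the trainer averages per-rank losses, this weighting yields the
    # token-weighted mean rather than a simple mean of per-rank means.
    return loss * world_size * (local_tokens / group_tokens)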

tests/e2e/multigpu/test_sp.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+"""E2E tests for sequence parallelism"""
+
+import os
+from pathlib import Path
+
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port
+
+from axolotl.utils.dict import DictDefault
+
+from ..utils import check_tensorboard
+
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestSequenceParallelism:
+    """Test case for training with sequence parallelism enabled"""
+
+    def test_sequence_parallel_training(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "load_in_8bit": False,
+                "load_in_4bit": True,
+                "strict": False,
+                "sequence_len": 2048,
+                "adapter": "qlora",
+                "sample_packing": True,
+                "eval_sample_packing": True,
+                "pad_to_sequence_len": True,
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+                "special_tokens": {"pad_token": "<|endoftext|>"},
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 8,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "loss_watchdog_threshold": 5.0,
+                "loss_watchdog_patience": 3,
+                "bf16": "auto",
+                "warmup_steps": 1,
+                "saves_per_epoch": 1,
+                "logging_steps": 1,
+                "weight_decay": 0.0,
+                "use_tensorboard": True,
+                "sequence_parallel_degree": 2,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
+        )
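
Assuming a machine with at least two CUDA GPUs and `ring-flash-attn` installed, this smoke test can be run on its own with standard pytest invocation, e.g. `pytest tests/e2e/multigpu/test_sp.py`; the test itself launches the two-process training run via `accelerate launch`.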
