1 | | -"""Module for Axolotl trainer sequence parallelism mixin"""
| 1 | +"""
| 2 | +Module for Axolotl trainer sequence parallelism mixin and training context manager
| 3 | +"""
2 | 4 |
| 5 | +import functools
3 | 6 | import logging
4 | 7 |
| 8 | +import torch
5 | 9 | import torch.distributed as dist
6 | 10 | from datasets import Dataset
| 11 | +from torch import nn
7 | 12 | from torch.utils.data import DistributedSampler, Sampler
| 13 | +from torch.utils.hooks import RemovableHandle
8 | 14 |
9 | | -from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
| 15 | +from axolotl.monkeypatch.attention.ring_attn import (
| 16 | + RingAttnFunc,
| 17 | + get_ring_attn_group,
| 18 | + update_ring_attn_params,
| 19 | +)
10 | 20 |
11 | 21 | LOG = logging.getLogger(__name__)
12 | 22 |
13 | 23 |
| 24 | +def apply_sequence_parallelism( |
| 25 | + batch: dict[str, torch.Tensor], |
| 26 | + local_rank: int, |
| 27 | + local_world_size: int, |
| 28 | + ring_attn_func: RingAttnFunc, |
| 29 | +) -> dict[str, torch.Tensor]: |
| 30 | + """ |
| 31 | + Apply sequence parallelism slicing to a batch. |
| 32 | +
| 33 | + Args: |
| 34 | + batch: Batch dictionary (e.g., input_ids, attention_mask, etc.) |
| 35 | + local_rank: Local rank in the sequence parallel group |
| 36 | + local_world_size: World size of the sequence parallel group |
| 37 | + ring_attn_func: The ring attention function to use |
| 38 | +
| 39 | + Returns: |
| 40 | + Sliced batch dictionary. |
| 41 | + """ |
| 42 | + # Update ring attention params if needed |
| 43 | + if batch.get("position_ids") is not None: |
| 44 | + update_ring_attn_params(position_ids=batch["position_ids"]) |
| 45 | + |
| 46 | + # Slice batch for sequence parallel processing |
| 47 | + total_seq_len = batch["input_ids"].size(1) |
| 48 | + for key in batch: |
| 49 | + if ( |
| 50 | + key in batch |
| 51 | + and isinstance(batch[key], torch.Tensor) |
| 52 | + and batch[key].dim() > 1 |
| 53 | + and batch[key].size(1) == total_seq_len |
| 54 | + ): |
| 55 | + |
| 56 | + if ring_attn_func in [ |
| 57 | + RingAttnFunc.VARLEN_LLAMA3, |
| 58 | + RingAttnFunc.BATCH_RING, |
| 59 | + ]: |
| 60 | + # Split in sequential fashion and grab this rank's chunk |
| 61 | + batch[key] = ( |
| 62 | + batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous() |
| 63 | + ) |
| 64 | + elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: |
| 65 | + chunks = batch[key].chunk(2 * local_world_size, dim=1) |
| 66 | + |
| 67 | + # Take rank's chunk and opposing chunk for zigzag pattern |
| 68 | + selected_chunks = [ |
| 69 | + chunks[local_rank], |
| 70 | + chunks[2 * local_world_size - local_rank - 1], |
| 71 | + ] |
| 72 | + batch[key] = torch.cat(selected_chunks, dim=1).contiguous() |
| 73 | + elif ring_attn_func is RingAttnFunc.BATCH_STRIPE: |
| 74 | + # Split into striped data and stack |
| 75 | + tensor = torch.stack( |
| 76 | + batch[key].split(local_world_size, dim=1), |
| 77 | + dim=1, |
| 78 | + ).transpose(1, 2) |
| 79 | + batch[key] = tensor[:, local_rank].contiguous() |
| 80 | + |
| 81 | + return batch |
| 82 | + |
| 83 | + |
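For intuition, here is a minimal single-process sketch (not part of the diff) of how the three branches above partition an 8-token sequence across a sequence-parallel group of size 2. The helper name `shard_for_rank` and the string `mode` argument are illustrative stand-ins for the `RingAttnFunc` dispatch; only plain torch ops are used, so it runs without any distributed setup.

import torch

def shard_for_rank(x: torch.Tensor, rank: int, world_size: int, mode: str) -> torch.Tensor:
    # Mirrors the three slicing branches in apply_sequence_parallelism.
    if mode == "ring":  # sequential chunks (VARLEN_LLAMA3 / BATCH_RING)
        return x.chunk(world_size, dim=1)[rank].contiguous()
    if mode == "zigzag":  # this rank's chunk plus its mirror-image chunk (BATCH_ZIGZAG)
        chunks = x.chunk(2 * world_size, dim=1)
        return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=1).contiguous()
    # "stripe": every world_size-th position, offset by rank (BATCH_STRIPE)
    striped = torch.stack(x.split(world_size, dim=1), dim=1).transpose(1, 2)
    return striped[:, rank].contiguous()

tokens = torch.arange(8).unsqueeze(0)  # token positions 0..7, batch size 1
for mode in ("ring", "zigzag", "stripe"):
    print(mode, [shard_for_rank(tokens, r, 2, mode).squeeze(0).tolist() for r in (0, 1)])
# ring   [[0, 1, 2, 3], [4, 5, 6, 7]]
# zigzag [[0, 1, 6, 7], [2, 3, 4, 5]]
# stripe [[0, 2, 4, 6], [1, 3, 5, 7]]

The zigzag and stripe layouts are exactly what gather_tensor in the context manager below has to invert when it reassembles full-sequence outputs.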
14 | 84 | class SequenceParallelMixin:
15 | 85 | """
16 | 86 | Mixin class for sequence parallelism support in trainers.
@@ -87,3 +157,157 @@ def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
87 | 157 | return self._create_sequence_parallel_sampler(
88 | 158 | eval_dataset, shuffle=False, is_eval=True
89 | 159 | )
| 160 | + |
| 161 | + |
| 162 | +class SequenceParallelContextManager: |
| 163 | + """ |
| 164 | + Context manager for sequence parallelism operations. |
| 165 | +
| 166 | + This class provides a context that will automatically apply sequence parallelism |
| 167 | + during model forward passes using a pre-forward hook, and gather outputs from |
| 168 | + across the sequence parallelism group using a post-forward hook. |
| 169 | + """ |
| 170 | + |
| 171 | + def __init__( |
| 172 | + self, |
| 173 | + model: nn.Module, |
| 174 | + sequence_parallel_degree: int, |
| 175 | + ring_attn_func: RingAttnFunc, |
| 176 | + ): |
| 177 | + self.model = model |
| 178 | + self.sequence_parallel_degree = sequence_parallel_degree |
| 179 | + self.ring_attn_func = ring_attn_func |
| 180 | + self.process_group = get_ring_attn_group() |
| 181 | + |
| 182 | + # Initialize sequence parallel group details |
| 183 | + self.local_rank = dist.get_rank(self.process_group) |
| 184 | + self.local_world_size = dist.get_world_size(self.process_group) |
| 185 | + |
| 186 | + # Will store hook handles for removal |
| 187 | + self.hook_handles: list[RemovableHandle] = [] |
| 188 | + |
| 189 | + # Create a partially applied version of the apply_sequence_parallelism function |
| 190 | + # with pre-configured params |
| 191 | + self.apply_sequence_parallelism = functools.partial( |
| 192 | + apply_sequence_parallelism, |
| 193 | + local_rank=self.local_rank, |
| 194 | + local_world_size=self.local_world_size, |
| 195 | + ring_attn_func=self.ring_attn_func, |
| 196 | + ) |
| 197 | + |
| 198 | + def __enter__(self): |
| 199 | + # Forward pre-hook to apply sequence parallelism |
| 200 | + def sequence_parallel_pre_hook(_, args, kwargs): |
| 201 | + # Apply sequence parallelism to kwargs |
| 202 | + kwargs = self.apply_sequence_parallelism(batch=kwargs) |
| 203 | + return args, kwargs |
| 204 | + |
| 205 | + # Forward post-hook to gather outputs |
| 206 | + def sequence_parallel_post_hook(_, __, output): |
| 207 | + # Gather the sharded outputs |
| 208 | + return self.gather_outputs(output) |
| 209 | + |
| 210 | + # Register both hooks |
| 211 | + self.hook_handles.append( |
| 212 | + self.model.register_forward_pre_hook( |
| 213 | + sequence_parallel_pre_hook, with_kwargs=True |
| 214 | + ) |
| 215 | + ) |
| 216 | + self.hook_handles.append( |
| 217 | + self.model.register_forward_hook(sequence_parallel_post_hook) |
| 218 | + ) |
| 219 | + |
| 220 | + return self |
| 221 | + |
| 222 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 223 | + # Remove all hooks |
| 224 | + for handle in self.hook_handles: |
| 225 | + handle.remove() |
| 226 | + self.hook_handles = [] |
| 227 | + |
| 228 | + def gather_outputs(self, output): |
| 229 | + """Gather sharded outputs from all ranks and reconstruct the full tensor.""" |
| 230 | + # Handle different output formats (dict, tensor, etc.) |
| 231 | + if isinstance(output, dict): |
| 232 | + gathered_output = {} |
| 233 | + for key, value in output.items(): |
| 234 | + if isinstance(value, torch.Tensor) and value.dim() > 1: |
| 235 | + # Gather logits or other sequence-sharded tensors |
| 236 | + gathered_value = self.gather_tensor(value) |
| 237 | + gathered_output[key] = gathered_value |
| 238 | + else: |
| 239 | + gathered_value = value.clone() |
| 240 | + dist.all_reduce( |
| 241 | + gathered_value, op=dist.ReduceOp.SUM, group=self.process_group |
| 242 | + ) |
| 243 | + gathered_output[key] = gathered_value |
| 244 | + return gathered_output |
| 245 | + if isinstance(output, torch.Tensor): |
| 246 | + return self.gather_tensor(output) |
| 247 | + |
| 248 | + return output |
| 249 | + |
| 250 | + def gather_tensor(self, tensor): |
| 251 | + """Gather a sharded tensor from all ranks.""" |
| 252 | + # Prepare tensors for all_gather |
| 253 | + world_size = self.local_world_size |
| 254 | + |
| 255 | + # Create list to store tensors from all ranks |
| 256 | + gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] |
| 257 | + |
| 258 | + # All-gather operation |
| 259 | + dist.all_gather(gathered_tensors, tensor, group=self.process_group) |
| 260 | + |
| 261 | + # Concatenate along sequence dimension (typically dim=1) |
| 262 | + if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]: |
| 263 | + # Simple concatenation for standard sharding |
| 264 | + return torch.cat(gathered_tensors, dim=1) |
| 265 | + |
| 266 | + if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: |
| 267 | + # Each rank has a pattern of (rank, world_size*2-rank-1) |
| 268 | + reconstituted_tensors = [None] * (world_size * 2) |
| 269 | + |
| 270 | + # First, split each gathered tensor into its two chunks |
| 271 | + for rank, gathered_tensor in enumerate(gathered_tensors): |
| 272 | + # Each tensor contains two chunks in the sequence dimension |
| 273 | + chunk_size = gathered_tensor.size(1) // 2 |
| 274 | + chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1) |
| 275 | + |
| 276 | + # Place chunks in their original positions |
| 277 | + reconstituted_tensors[rank] = chunk1 |
| 278 | + reconstituted_tensors[world_size * 2 - rank - 1] = chunk2 |
| 279 | + |
| 280 | + # Concatenate the reconstituted tensors in the correct order |
| 281 | + return torch.cat(reconstituted_tensors, dim=1) |
| 282 | + |
| 283 | + # Otherwise, RingAttnFunc.BATCH_STRIPE |
| 284 | + # In striping, each rank has every world_size-th slice |
| 285 | + batch_size = tensor.size(0) |
| 286 | + hidden_dim = tensor.size(-1) |
| 287 | + |
| 288 | + # First, determine the full sequence length |
| 289 | + total_seq_len = 0 |
| 290 | + for t in gathered_tensors: |
| 291 | + total_seq_len += t.size(1) |
| 292 | + |
| 293 | + # Create a tensor to hold the unstriped result |
| 294 | + result = torch.zeros( |
| 295 | + batch_size, |
| 296 | + total_seq_len, |
| 297 | + hidden_dim, |
| 298 | + dtype=tensor.dtype, |
| 299 | + device=tensor.device, |
| 300 | + ) |
| 301 | + |
| 302 | + # For each rank's tensor, distribute its slices to the correct positions |
| 303 | + for rank, gathered_tensor in enumerate(gathered_tensors): |
| 304 | + # The rank's tensor contains every world_size-th slice |
| 305 | + # starting from its rank position |
| 306 | + seq_len = gathered_tensor.size(1) |
| 307 | + for i in range(seq_len): |
| 308 | + # Calculate the position in the full tensor |
| 309 | + pos = i * world_size + rank |
| 310 | + if pos < total_seq_len: |
| 311 | + result[:, pos] = gathered_tensor[:, i] |
| 312 | + |
| 313 | + return result |
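A minimal usage sketch of the new context manager (not from the diff): it assumes torch.distributed and the axolotl ring-attention patch are already initialized, e.g. by the trainer, and that `model` and `batch` are an already-prepared module and collated keyword-argument batch; the degree value is arbitrary.

with SequenceParallelContextManager(
    model=model,
    sequence_parallel_degree=4,
    ring_attn_func=RingAttnFunc.BATCH_RING,
):
    # The pre-forward hook shards every full-sequence tensor in the kwargs for
    # this rank; the post-forward hook all-gathers the sharded logits so the
    # caller sees full-sequence outputs again.
    outputs = model(**batch)

Inside the with block each rank only attends over its local shard, but the returned logits recover the original sequence length via the post-forward gather.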