|
1 | 1 | """Module for Axolotl trainer sequence parallelism mixin"""
|
2 | 2 |
|
3 | 3 | import logging
|
4 |
| -from typing import Any |
5 | 4 |
|
6 |
| -import torch |
7 | 5 | import torch.distributed as dist
|
8 |
| -import torch.nn.functional as F |
9 | 6 | from datasets import Dataset
|
10 |
| -from torch import nn |
11 | 7 | from torch.utils.data import DistributedSampler, Sampler
|
12 | 8 |
|
13 | 9 | from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
14 | 10 |
|
15 | 11 | LOG = logging.getLogger(__name__)
|
16 | 12 |
|
17 |
| -try: |
18 |
| - from ring_flash_attn import update_ring_flash_attn_params |
19 |
| -except ImportError: |
20 |
| - # We pass silently here, but raise an ImportError in our Axolotl config validation |
21 |
| - # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed. |
22 |
| - pass |
23 |
| - |
24 | 13 |
|
25 | 14 | class SequenceParallelMixin:
|
26 | 15 | """
|
27 | 16 | Mixin class for sequence parallelism support in trainers.
|
28 | 17 |
|
29 | 18 | This mixin provides functionality for handling sequence parallelism,
|
30 |
| - including creating appropriate samplers, managing data partitioning, |
31 |
| - and updating ring flash attention parameters during training. |
| 19 | + specifically for creating appropriate data samplers. |
32 | 20 | """
|
33 | 21 |
|
34 | 22 | args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
@@ -99,84 +87,3 @@ def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
99 | 87 | return self._create_sequence_parallel_sampler(
|
100 | 88 | eval_dataset, shuffle=False, is_eval=True
|
101 | 89 | )
|
102 |
| - |
103 |
| - def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]): |
104 |
| - """ |
105 |
| - Calculate the cu_seqlens for the current forward pass and pass the value to |
106 |
| - the substituted ring_flash_attn. This is accomplished by using the passed |
107 |
| - `input_ids`. |
108 |
| -
|
109 |
| - Args: |
110 |
| - inputs: Current batch of inputs. |
111 |
| - """ |
112 |
| - # At this point, inputs should already be partitioned by the sequence |
113 |
| - # parallel data collator |
114 |
| - batch_size = inputs["input_ids"].shape[0] |
115 |
| - seq_len = inputs["input_ids"].shape[1] |
116 |
| - packed_seq_lens = [seq_len] * batch_size |
117 |
| - |
118 |
| - # Calculate the full sequence length across all GPUs in this SP group |
119 |
| - total_seq_len = seq_len * self.args.sequence_parallel_degree |
120 |
| - |
121 |
| - cu_seqlens = torch.cumsum( |
122 |
| - torch.tensor( |
123 |
| - packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 |
124 |
| - ), |
125 |
| - dim=-1, |
126 |
| - dtype=torch.int32, |
127 |
| - ) |
128 |
| - cu_seqlens = F.pad( |
129 |
| - F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len |
130 |
| - ) |
131 |
| - |
132 |
| - update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) |
133 |
| - |
134 |
| - def training_step( |
135 |
| - self, |
136 |
| - model: nn.Module, |
137 |
| - inputs: dict[str, torch.Tensor | Any], |
138 |
| - num_items_in_batch: int | None = None, |
139 |
| - ) -> torch.Tensor: |
140 |
| - """ |
141 |
| - Perform a training step on a batch of inputs. Overrides the |
142 |
| - `transformers.trainer.Trainer` method to handle sequence parallelism if |
143 |
| - enabled. |
144 |
| -
|
145 |
| - Args: |
146 |
| - model: Model to perform training step for. |
147 |
| - inputs: Dictionary mapping. |
148 |
| - """ |
149 |
| - # Set up sequence parallelism for this step if enabled |
150 |
| - if self.args.sequence_parallel_degree > 1: |
151 |
| - self._update_ring_flash_attn_params(inputs) |
152 |
| - |
153 |
| - # Proceed with normal training step |
154 |
| - return super().training_step(model, inputs, num_items_in_batch) # type: ignore |
155 |
| - |
156 |
| - def prediction_step( |
157 |
| - self, |
158 |
| - model: nn.Module, |
159 |
| - inputs: dict[str, torch.Tensor | Any], |
160 |
| - prediction_loss_only: bool, |
161 |
| - ignore_keys: list[str] | None = None, |
162 |
| - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: |
163 |
| - """ |
164 |
| - Perform a prediction step on a batch of inputs. Overrides the |
165 |
| - `transformers.trainer.Trainer` method to handle sequence parallelism if |
166 |
| - enabled. |
167 |
| -
|
168 |
| - Args: |
169 |
| - model: Model to perform prediction step for. |
170 |
| - inputs: Dictionary mapping of inputs. |
171 |
| - prediction_loss_only: Whether to return only the loss. |
172 |
| - ignore_keys: Keys to ignore in the inputs. |
173 |
| -
|
174 |
| - Returns: |
175 |
| - Tuple of (loss, logits, labels). |
176 |
| - """ |
177 |
| - # Set up sequence parallelism for this prediction step if enabled |
178 |
| - if self.args.sequence_parallel_degree > 1: |
179 |
| - self._update_ring_flash_attn_params(inputs) |
180 |
| - |
181 |
| - # Proceed with normal prediction step |
182 |
| - return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore |
0 commit comments