Commit 02da373

ALST/UlyssesSP: more intuitive API wrt variable seqlen (#7656)
As I was integrating ALST/Ulysses SP into HF Accelerate/Trainer, I noticed that the initial `UlyssesSPAttentionHF.register_with_transformers` API was a bit inflexible and confusing with respect to variable sequence lengths. This PR deprecates the misleading `max_length` arg name, replaces it with `seq_length`, and makes the latter optional if `seq_length_is_variable` is True. Updated tests and docs.

Signed-off-by: Stas Bekman <[email protected]>
1 parent 433e3c7 commit 02da373
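To make the change concrete, here is a minimal before/after sketch of a call site, assuming a distributed launch (e.g. `torchrun` with 2 ranks) where DeepSpeed's distributed backend has already been initialized; the import path is taken from the file touched by this commit and the argument values are illustrative only:

```python
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF

# Old API (now deprecated): `max_length` had to be passed even when batches
# varied in length; passing it still works but emits a deprecation warning.
#
# mpu = UlyssesSPAttentionHF.register_with_transformers(
#     model_name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM",
#     core_attn_implementation="sdpa",
#     sequence_parallel_size=2,
#     max_length=64,
#     micro_batch_size=1,
#     seq_length_is_variable=True,
# )

# New API: `seq_length` replaces `max_length` and may be omitted entirely
# when `seq_length_is_variable=True` (the default).
mpu = UlyssesSPAttentionHF.register_with_transformers(
    model_name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM",
    core_attn_implementation="sdpa",
    sequence_parallel_size=2,
    micro_batch_size=1,
    seq_length_is_variable=True,
)
```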

File tree

3 files changed: +75 −20 lines

deepspeed/runtime/sequence_parallel/ulysses_sp.py

Lines changed: 42 additions & 13 deletions

@@ -32,6 +32,7 @@
 from collections import defaultdict
 from deepspeed.runtime.utils import see_memory_usage
 from deepspeed.sequence.layer import _DimZeroAllToAll
+from deepspeed.utils.logging import logger
 from einops import rearrange
 from packaging import version
 from torch import Tensor
@@ -68,15 +69,15 @@ class UlyssesSPAttentionHF(torch.nn.Module):
 
     Arguments:
         attn: normal attention implementation from transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS
-        local_seq_length (int): local sequence length per GPU
-        global_seq_length (int): actual sequence length
+        seq_length_is_variable (bool): whether global seqlen may change between batches
+        local_seq_length (int): local sequence length per GPU or None if seq_length_is_variable is True
+        global_seq_length (int): actual sequence length or None if seq_length_is_variable is True
         batch_size (int): batch size
         attn_head_size (int): size of each attention head
         attn_head_count (int): total number of attention heads
         kv_head_count (int): total number of kv heads
         num_hidden_layers (int): total number of layers
         process_group (dist.ProcessGroup): Ulysses process group
-        seq_length_is_variable (bool): whether global seqlen may change between batches
 
 
     Extras:
@@ -86,26 +87,26 @@ class UlyssesSPAttentionHF(torch.nn.Module):
     def __init__(
         self,
         attn,
-        local_seq_length: int,
-        global_seq_length: int,
         batch_size: int,
         attn_head_count: int,
         attn_head_size: int,
         kv_head_count: int,
         num_hidden_layers: int,
         process_group: dist.ProcessGroup,
         seq_length_is_variable: bool = False,
+        local_seq_length: int = None,
+        global_seq_length: int = None,
     ) -> None:
         super().__init__()
         self.attn = attn
         self.process_group = process_group
         self.world_size = dist.get_world_size(process_group)
         self.sp_rank = dist.get_rank(process_group)
 
-        self.local_seq_length = local_seq_length
-        self.global_seq_length = global_seq_length
         self.batch_size = batch_size
         self.seq_length_is_variable = seq_length_is_variable
+        self.local_seq_length = local_seq_length
+        self.global_seq_length = global_seq_length
 
         self.attn_head_size = attn_head_size
         self.attn_head_count = attn_head_count
@@ -138,6 +139,12 @@ def __init__(
                 f"KV attention head count {self.global_kv_head_count} is not divisible by SP size {self.world_size} or"
                 " vice versa")
 
+        if self.seq_length_is_variable:
+            # the self.required_*_shape depending on the following will get updated in `forward`
+            # use 1 as a placeholder for dim=0 to keep torch.Size happy
+            local_seq_length = 1
+            global_seq_length = 1
+
         # [sl_l bs hc hs]
         self.required_query_shape = torch.Size([local_seq_length, batch_size, attn_head_count, attn_head_size])
         self.required_key_value_shape = torch.Size([local_seq_length, batch_size, kv_head_count, attn_head_size])
@@ -239,8 +246,8 @@ def forward(
         # print_rank0(f"{key.shape=}")
         # print_rank0(f"{value.shape=}")
         # print_rank0(f"{self.required_input_shape=}")
-        current_local_seq_length = query.shape[2]
-        if self.seq_length_is_variable and current_local_seq_length != self.required_query_shape[0]:
+        if self.seq_length_is_variable:
+            current_local_seq_length = query.shape[2]
             self.local_seq_length = current_local_seq_length
             self.global_seq_length = current_local_seq_length * self.world_size
             # update the required seqlen shapes
@@ -340,9 +347,11 @@ def register_with_transformers(
         model_name_or_path,
         core_attn_implementation,
         sequence_parallel_size,
-        max_length,
         micro_batch_size,
+        seq_length=None,
         seq_length_is_variable=True,
+        # deprecated
+        max_length=None,
     ):
         """
         Register "ulysses" attn_implementation with HF transformers and return mpu (Megatron-LM-style parallel state groups object).
@@ -352,14 +361,26 @@ def register_with_transformers(
        - model_name_or_path (object or str): model object, or HF hub model name, or model's local path
        - core_attn_implementation (str): which attention to use: flash_attention_2 or flash_attention_3 or sdpa
        - sequence_parallel_size (int): sequence parallelism dimension (if 1 it's disabled)
-        - max_length (int): actual global sequence length
        - micro_batch_size (int): micro batch size
+        - seq_length (int): set this argument if the sequence length is fixed in all batches
        - seq_length_is_variable (bool): whether global seqlen may change between batches an optimization flag - the default is `True`
+        - max_length (int): actual global sequence length - this argument is deprecated - use `seq_length` instead
 
        """
        if sequence_parallel_size == 1:
            return None
 
+        if max_length is not None:
+            logger.warning(
+                "The 'max_length` argument is deprecated and will be eventually removed, please use `seq_length` instead"
+            )
+        if seq_length is None and max_length is not None:
+            seq_length = max_length
+        if not seq_length_is_variable and seq_length is None:
+            raise ValueError(
+                "Either `seq_length_is_variable` needs to be `True` or `seq_length` needs to be set to an integer value of the fixed batch size length."
+            )
+
        from transformers import AutoConfig
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
@@ -390,10 +411,16 @@ def register_with_transformers(
                f"{core_attn_implementation} is not a valid attn_implementation. The choices are {ALL_ATTENTION_FUNCTIONS.valid_keys()}"
            )
        core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation]
+
+        if seq_length_is_variable:
+            local_seq_length = None
+            global_seq_length = None
+        else:
+            local_seq_length = seq_length // mpu.get_sequence_parallel_world_size()
+            global_seq_length = seq_length
+
        uattn = UlyssesSPAttentionHF(
            attn=core_attn_function,
-            local_seq_length=max_length // mpu.get_sequence_parallel_world_size(),
-            global_seq_length=max_length,
            batch_size=micro_batch_size,
            attn_head_count=hf_model_config.num_attention_heads,
            attn_head_size=getattr(hf_model_config, "head_dim",
@@ -402,6 +429,8 @@ def register_with_transformers(
            num_hidden_layers=hf_model_config.num_hidden_layers,
            process_group=mpu.get_sequence_parallel_group(),
            seq_length_is_variable=seq_length_is_variable,
+            local_seq_length=local_seq_length,
+            global_seq_length=global_seq_length,
        )
 
        def uattn_wrapper(
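To summarize the mechanics of the `seq_length_is_variable` path added above, here is a tiny standalone sketch (plain `torch`, with names local to the snippet, not DeepSpeed's internal code) of the per-forward shape bookkeeping: the local sequence length is read from the incoming query tensor's seqlen dimension, the global length is derived from the SP world size, and the required shapes are rebuilt from it.

```python
import torch

# Toy stand-ins for values the real module gets from the process group and model config.
world_size = 2        # sequence-parallel degree
batch_size = 1
attn_head_count = 4
attn_head_size = 16

def rebuild_required_shapes(query: torch.Tensor):
    # The diff reads the per-batch local sequence length off query.shape[2]
    # (the seqlen dim in the usual [batch, heads, seqlen, head_dim] layout).
    local_seq_length = query.shape[2]
    global_seq_length = local_seq_length * world_size
    # [sl_l, bs, hc, hs] -- mirrors required_query_shape in the diff above
    required_query_shape = torch.Size([local_seq_length, batch_size, attn_head_count, attn_head_size])
    return local_seq_length, global_seq_length, required_query_shape

q = torch.randn(batch_size, attn_head_count, 8, attn_head_size)
print(rebuild_required_shapes(q))  # (8, 16, torch.Size([8, 1, 4, 16]))
```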

docs/_tutorials/ulysses-alst-sequence-parallelism.md

Lines changed: 31 additions & 5 deletions

@@ -36,7 +36,7 @@ import deepspeed.comm as dist
 import torch
 
 model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
-max_length = 64
+seq_length = 64
 sequence_parallel_size = 2
 micro_batch_size = 1
 
@@ -74,8 +74,8 @@ mpu = UlyssesSPAttentionHF.register_with_transformers(
     model_name_or_path=model_name_or_path,
     core_attn_implementation="sdpa",
     sequence_parallel_size=sequence_parallel_size,
-    max_length=max_length,
     micro_batch_size=micro_batch_size,
+    seq_length=seq_length,
     seq_length_is_variable=True,
 )
 
@@ -151,16 +151,42 @@ mpu = UlyssesSPAttentionHF.register_with_transformers(
     model_name_or_path=model_name_or_path,
     core_attn_implementation="sdpa",
     sequence_parallel_size=sequence_parallel_size,
-    max_length=max_length,
     micro_batch_size=micro_batch_size,
+    seq_length=seq_length,
     seq_length_is_variable=True,
 )
 ```
 
 It also creates nccl process groups encapsulated by the `mpu` object it returns.
 
+For the `model_name_or_path` argument you can also pass the already existing HF Transformers `model` object.
+
 `UlyssesSPAttentionHF.register_with_transformers` has to be called before `from_pretrained` is called.
 
+If `seq_length_is_variable` is `True` (which is also the default value), `UlyssesSPAttentionHF` will recalculate the shapes on each `forward` based on the incoming batch's shapes - in which case you don't need to set `seq_length` - you can just skip it like so:
+```
+mpu = UlyssesSPAttentionHF.register_with_transformers(
+    model_name_or_path=model_name_or_path,
+    core_attn_implementation="sdpa",
+    sequence_parallel_size=sequence_parallel_size,
+    micro_batch_size=micro_batch_size,
+    seq_length_is_variable=True,
+)
+```
+
+If, however, all your batches have an identical sequence length, then you'd save a few microseconds per run with using the `seq_length_is_variable=False` code path, which will pre-measure all shapes once and re-use them in all runs:
+
+```
+mpu = UlyssesSPAttentionHF.register_with_transformers(
+    [...]
+    seq_length=seq_length,
+    seq_length_is_variable=False,
+)
+```
+
+If you pass `seq_length`, remember that it has to be divisible by `sequence_parallel_size`. And of course, this also applies to all batches, even if you use `seq_length_is_variable=True`.
+
+
 ### UlyssesSPDataLoaderAdapter
 
 ```python
@@ -173,9 +199,9 @@ dl = UlyssesSPDataLoaderAdapter(
 )
 ```
 
-This takes an existing DataLoader object and returns a new one that will shard the batches on the sequence dimension and synchronize all GPUs of the replica to return only its corresponding shard.
+This takes an existing DataLoader object and returns a new one that will shard the batches on the sequence dimension and synchronize all GPUs of the replica to return to each rank only its corresponding sequence shard.
 
-It also takes care of pre-shifting labels and replacing `labels` with `shift_labels` in the batch.
+It also takes care of replacing `labels` with `shift_labels` in the batch, by pre-shifting labels, which is crucial for the correct loss calculation when using Ulysses sequence parallelism.
 
 ### Loss averaging
 
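The new tutorial note above says that `seq_length` - and in fact every batch's sequence length - must be divisible by `sequence_parallel_size`, even with `seq_length_is_variable=True`. Below is a hedged collate-time sketch of one way to satisfy that by right-padding up to the next multiple; the helper name, the `pad_token_id` default, and the single-tensor signature are assumptions for illustration, not part of this commit (labels and attention masks would need the same treatment).

```python
import torch

def pad_to_sp_multiple(input_ids: torch.Tensor, sequence_parallel_size: int, pad_token_id: int = 0) -> torch.Tensor:
    """Right-pad [bs, seqlen] input_ids so that seqlen % sequence_parallel_size == 0."""
    seqlen = input_ids.shape[1]
    remainder = seqlen % sequence_parallel_size
    if remainder == 0:
        return input_ids
    pad_len = sequence_parallel_size - remainder
    pad = torch.full((input_ids.shape[0], pad_len), pad_token_id, dtype=input_ids.dtype)
    return torch.cat([input_ids, pad], dim=1)

# A batch of length 7 with sequence_parallel_size=2 becomes length 8.
batch = torch.ones(1, 7, dtype=torch.long)
print(pad_to_sp_multiple(batch, sequence_parallel_size=2).shape)  # torch.Size([1, 8])
```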

tests/unit/ulysses_alst/test_ulysses_sp_hf.py

Lines changed: 2 additions & 2 deletions

@@ -36,7 +36,7 @@ class TestUlyssesSPHF(DistributedTest):
     def test_ulysses_sp_hf(self, zero_stage):
         model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
         #model_name_or_path = 'Felladrin/Llama-160M-Chat-v1'
-        max_length = 64
+        seq_length = 64
         sequence_parallel_size = self.world_size
         micro_batch_size = 1
 
@@ -105,8 +105,8 @@ def collate_fn(batch):
            model_name_or_path=model_name_or_path,
            core_attn_implementation="sdpa",
            sequence_parallel_size=sequence_parallel_size,
-            max_length=max_length,
            micro_batch_size=micro_batch_size,
+            seq_length=seq_length,
            seq_length_is_variable=True,
        )
 