
Commit 64c0052

Ulysses HF Accelerate integration (#7638)
Ulysses/ALST integration with HF Accelerate:

- Allow `UlyssesSPAttentionHF.register_with_transformers` to accept a `model` object as an argument, to match HF Accelerate's workflow
- Fix existing Ulysses tests to test z2 instead of z1
- Improve documentation
- Add a defensive check

The HF Accelerate PR that depends on this PR: huggingface/accelerate#3817

Signed-off-by: Stas Bekman <[email protected]>
1 parent 9c86cd9

File tree

3 files changed, +37 -11 lines changed

deepspeed/runtime/sequence_parallel/ulysses_sp.py

Lines changed: 34 additions & 8 deletions
@@ -345,8 +345,16 @@ def register_with_transformers(
         seq_length_is_variable=True,
     ):
         """
-        Register "ulysses" attn_implementation with HF transformers and return mpu (Megatron-LM-style parallel state object).
-        If sequence_parallel_size==1 do nothng and return None.
+        Register "ulysses" attn_implementation with HF transformers and return mpu (Megatron-LM-style parallel state groups object).
+        If sequence_parallel_size==1 do nothing and return None.
+
+        Args:
+        - model_name_or_path (object or str): model object, or HF hub model name, or model's local path
+        - core_attn_implementation (str): which attention to use: flash_attention_2 or flash_attention_3 or sdpa
+        - sequence_parallel_size (int): sequence parallelism dimension (if 1 it's disabled)
+        - max_length (int): actual global sequence length
+        - micro_batch_size (int): micro batch size
+        - seq_length_is_variable (bool): whether global seqlen may change between batches an optimization flag - the default is `True`

         """
         if sequence_parallel_size == 1:
@@ -359,8 +367,14 @@ def register_with_transformers(

         mpu.initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size)

-        # we don't have the model yet at this stage
-        hf_model_config = AutoConfig.from_pretrained(model_name_or_path)
+        from transformers import PreTrainedModel
+        if isinstance(model_name_or_path, PreTrainedModel):
+            # we already have the model
+            hf_model_config = model_name_or_path.config
+        else:
+            # if we don't have the model yet at this stage
+            hf_model_config = AutoConfig.from_pretrained(model_name_or_path)
+
         supported_attn_implementation = ["flash_attention_2", "flash_attention_3", "sdpa"]
         if core_attn_implementation not in supported_attn_implementation:
             # notes on the excluded ones:
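With this change `register_with_transformers` accepts either a model name/path or an already-instantiated `PreTrainedModel`, which matches HF Accelerate's workflow where the model exists before DeepSpeed is wired up. A minimal sketch of the new call path, assuming the method is invoked on `UlyssesSPAttentionHF` as named in the commit message; the checkpoint name, sizes, and handling of the returned `mpu` are illustrative, not taken from this diff:

```python
# Hedged sketch: pass an already-created model object instead of a name/path.
from transformers import AutoModelForCausalLM
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # hypothetical checkpoint

mpu = UlyssesSPAttentionHF.register_with_transformers(
    model_name_or_path=model,           # a PreTrainedModel object now works here
    core_attn_implementation="sdpa",    # or flash_attention_2 / flash_attention_3
    sequence_parallel_size=4,           # SP degree; 1 disables Ulysses and returns None
    max_length=8192,                    # global sequence length
    micro_batch_size=1,
    seq_length_is_variable=True,
)
```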
@@ -460,6 +474,19 @@ def __init__(

         If more tokens need to be consumed per step use the gradient accumulation feature.

+        Ulysses expects the following dict keys in each DL batch (`dl->iter->next`):
+        - `input_ids`
+        - `position_ids`
+        - `labels`
+
+        Additional entries can be present.
+
+        The tensors are expected to be of shape: `[batch_size, seqlen, ...]`
+
+        The sharding happens on the seqlen (1st) dimension for all tensors in the batch, any non-tensor entries get copied to all ranks.
+
+        `attention_mask` isn't used by Ulysses, because it's typically too large when it's 4D, and position_ids is just 1D, therefore it's much much smaller and consumes little GPU memory.
+
         Arguments:
         - `dl`: an existing DataLoader object to wrap
         - `sp_rank`: SP rank
@@ -469,10 +496,6 @@ def __init__(

         Returns:
         Another DataLoader object
-
-        Here are the current assumptions on the inputs fetched by dl->iter->next
-        - the batch is a dict with at least the keys: `input_ids`, `labels`, `position_ids` - but can have any additional keys necessary.
-        - the tensor values get sharded, the non-tensor values are passed along as is
         """

         self.dl = dl
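The new docstring pins down the batch contract for the DataLoader adapter. A minimal sketch of a conforming micro-batch, with toy shapes and values chosen purely for illustration:

```python
import torch

# Hypothetical micro-batch satisfying the documented contract: every tensor is
# [batch_size, seqlen]; extra keys are allowed and non-tensor values are simply
# replicated to every SP rank.
batch_size, seqlen = 1, 8
batch = {
    "input_ids": torch.randint(0, 32000, (batch_size, seqlen)),
    "position_ids": torch.arange(seqlen).unsqueeze(0),            # [1, seqlen]
    "labels": torch.randint(0, 32000, (batch_size, seqlen)),
    "sample_id": "doc-0042",  # non-tensor extra entry, passed along as-is
}
# No `attention_mask`: Ulysses relies on position_ids, which is far smaller than a 4D mask.
```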
@@ -515,6 +538,9 @@ def refill(self):
         for k in batch.keys():
             if torch.is_tensor(batch[k]):
                 batch[k] = batch[k].to(self.device)
+                if seqlen != batch[k].shape[1]:
+                    raise ValueError(
+                        f"{k}'s shape {batch[k].shape} must match input_ids's shape {batch['input_ids'].shape}")
                 with torch.no_grad():
                     tensor_list = [
                         torch.zeros((batch[k].shape[0], seqlens[i]), dtype=batch[k].dtype, device=batch[k].device)
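This is the defensive check from the commit message: tensors whose sequence length disagrees with `input_ids` are rejected before sharding. A standalone sketch of the same validation logic, assuming `seqlen` is taken from `input_ids` as in the surrounding code; the helper name is made up:

```python
import torch

def validate_batch_seqlen(batch: dict) -> None:
    """Standalone mirror of the new defensive check in refill()."""
    seqlen = batch["input_ids"].shape[1]
    for k, v in batch.items():
        if torch.is_tensor(v) and v.shape[1] != seqlen:
            raise ValueError(f"{k}'s shape {v.shape} must match input_ids's shape {batch['input_ids'].shape}")

# A labels tensor one token short now fails fast, before the gather/shard step:
bad = {
    "input_ids": torch.zeros(1, 8, dtype=torch.long),
    "labels": torch.zeros(1, 7, dtype=torch.long),
}
# validate_batch_seqlen(bad)  # raises ValueError
```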

tests/unit/ulysses_alst/test_tiled_compute.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def mlp_forward_sequence_tiled_compute(self, x):


 @pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.parametrize("zero_stage", [1, 3])
+@pytest.mark.parametrize("zero_stage", [2, 3])
 class TestTiledCompute(DistributedTest):
     world_size = 1

@@ -232,7 +232,7 @@ def test_tiled_mlp(self, zero_stage, batch_size):


 @pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.parametrize("zero_stage", [1, 3])
+@pytest.mark.parametrize("zero_stage", [2, 3])
 class TestTiledFusedLogitsLoss(DistributedTest):
     world_size = 1

tests/unit/ulysses_alst/test_ulysses_sp_hf.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def get_grad(param, zero_stage):
     # return safe_get_full_grad(param)


-@pytest.mark.parametrize("zero_stage", [1, 3])
+@pytest.mark.parametrize("zero_stage", [2, 3])
 class TestUlyssesSPHF(DistributedTest):
     world_size = 2
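All three test files now parametrize over ZeRO stages 2 and 3 instead of 1 and 3. The parametrized stage ultimately lands in a DeepSpeed config; a hedged sketch of how such a config is typically assembled (keys follow the standard DeepSpeed config schema, the helper name is hypothetical):

```python
# Hypothetical helper; keys follow the standard DeepSpeed config schema.
def make_ds_config(zero_stage: int, micro_batch_size: int = 1) -> dict:
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "zero_optimization": {"stage": zero_stage},  # 2 or 3 after this commit
        "bf16": {"enabled": True},
    }

ds_config = make_ds_config(zero_stage=2)
```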
