Ulysses/ALST integration with HF Accelerate:
- Allow `UlyssesSPAttentionHF.register_with_transformers` to get a
`model` obj as an argument, to match HF accelerate's workflow
- Fix existing Ulysses tests to test z2 instead of z1
- Improve documentation
- Add a defensive check
The HF Accelerate PR that depends on this PR is here:
huggingface/accelerate#3817
---------
Signed-off-by: Stas Bekman <[email protected]>
     If more tokens need to be consumed per step, use the gradient accumulation feature.
+
+    Ulysses expects the following dict keys in each DL batch (`dl->iter->next`):
+    - `input_ids`
+    - `position_ids`
+    - `labels`
+
+    Additional entries can be present.
+
+    The tensors are expected to be of shape: `[batch_size, seqlen, ...]`
+
+    The sharding happens on the seqlen (1st) dimension for all tensors in the batch; any non-tensor entries get copied to all ranks.
+
+    `attention_mask` isn't used by Ulysses because it's typically too large when it's 4D; `position_ids` is just 1D, therefore much smaller, and consumes little GPU memory.

     Arguments:
     - `dl`: an existing DataLoader object to wrap
     - `sp_rank`: SP rank
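The sharding contract described in the docstring above can be sketched as follows. This is a minimal illustration, not DeepSpeed's actual implementation: `shard_batch_for_rank` and its arguments are hypothetical names, and the sketch assumes seqlen divides evenly by the SP world size.

```python
import torch


def shard_batch_for_rank(batch, sp_rank, sp_world_size):
    """Hypothetical helper: slice every tensor in the batch along the
    seqlen (1st) dimension for one SP rank; copy non-tensor entries as-is."""
    sharded = {}
    for k, v in batch.items():
        if torch.is_tensor(v):
            seqlen = v.shape[1]
            # assumes seqlen is divisible by sp_world_size
            chunk = seqlen // sp_world_size
            sharded[k] = v[:, sp_rank * chunk:(sp_rank + 1) * chunk]
        else:
            # non-tensor entries get copied to all ranks unchanged
            sharded[k] = v
    return sharded


batch = {
    "input_ids": torch.arange(16).reshape(1, 16),
    "position_ids": torch.arange(16).reshape(1, 16),
    "labels": torch.arange(16).reshape(1, 16),
    "meta": "copied-to-all-ranks",
}
shard = shard_batch_for_rank(batch, sp_rank=1, sp_world_size=4)
print(shard["input_ids"])  # rank 1 of 4 gets tokens 4..7
```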
@@ -469,10 +496,6 @@ def __init__(

     Returns:
         Another DataLoader object
-
-    Here are the current assumptions on the inputs fetched by dl->iter->next
-    - the batch is a dict with at least the keys: `input_ids`, `labels`, `position_ids` - but can have any additional keys necessary.
-    - the tensor values get sharded, the non-tensor values are passed along as is
     """
     self.dl = dl
@@ -515,6 +538,9 @@ def refill(self):
         for k in batch.keys():
             if torch.is_tensor(batch[k]):
                 batch[k] = batch[k].to(self.device)
+                if seqlen != batch[k].shape[1]:
+                    raise ValueError(
+                        f"{k}'s shape {batch[k].shape} must match input_ids's shape {batch['input_ids'].shape}")