Commit 433e3c7
ulysses mpu: additional api (#7649)

It looks like `save_checkpoint` expects the `get_model_parallel_*` API in the `mpu` object, so this adds it to the Ulysses slim mpu version. This solves the following problem in HF Trainer:

```
[rank1]:   File "/code/users/stas/github/transformers-alst-integration/src/transformers/trainer.py", line 3248, in _save_optimizer_and_scheduler
[rank1]:     self.model_wrapped.save_checkpoint(output_dir)
[rank1]:   File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/engine.py", line 3497, in save_checkpoint
[rank1]:     self._save_checkpoint(save_dir,
[rank1]:   File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/engine.py", line 3709, in _save_checkpoint
[rank1]:     save_path = self._get_ckpt_name(save_dir, tag)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/engine.py", line 3039, in _get_ckpt_name
[rank1]:     mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
[rank1]:                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AttributeError: module 'deepspeed.runtime.sequence_parallel.parallel_state_sp' has no attribute 'get_model_parallel_rank'. Did you mean: 'get_sequence_parallel_rank'?
```

Signed-off-by: Stas Bekman <[email protected]>
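The failure mode in the traceback comes from DeepSpeed treating `mpu` as a duck-typed object: `_get_ckpt_name` falls back to rank 0 only when `mpu` is `None`, and otherwise calls `mpu.get_model_parallel_rank()` unconditionally. A minimal sketch of that call pattern (the `SlimMpu` class and `get_mp_rank` helper here are hypothetical stand-ins, not DeepSpeed's actual code):

```python
# Hypothetical sketch of the call pattern from the traceback: the caller
# assumes the mpu object exposes the model-parallel API, so a
# sequence-parallel-only mpu without the alias raises AttributeError.

class SlimMpu:
    """Stand-in for an mpu that only exposes the sequence-parallel API."""

    def get_sequence_parallel_rank(self):
        return 0


def get_mp_rank(mpu):
    # mirrors DeepSpeed's engine.py:
    #   mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
    return 0 if mpu is None else mpu.get_model_parallel_rank()


# No mpu at all is fine -- the caller falls back to rank 0:
print(get_mp_rank(None))  # prints 0

# An mpu missing the model-parallel alias reproduces the failure:
try:
    get_mp_rank(SlimMpu())
except AttributeError as e:
    print(type(e).__name__)  # prints AttributeError
```

Aliasing `get_model_parallel_rank` onto the sequence-parallel accessor, as this commit does, makes the second call succeed without changing the caller.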
Parent: 706f6e8 · Commit: 433e3c7

File tree

1 file changed (+6, -0 lines)


deepspeed/runtime/sequence_parallel/parallel_state_sp.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -88,3 +88,9 @@ def get_sequence_parallel_rank():
 def get_sequence_data_parallel_rank():
     """Return my rank for the sequence data parallel group."""
     return dist.get_rank(group=get_sequence_data_parallel_group())
+
+
+# since we only have 1 additional dimension over DP, we can just alias MP with SP
+get_model_parallel_rank = get_sequence_parallel_rank
+get_model_parallel_world_size = get_sequence_parallel_world_size
+get_model_parallel_group = get_sequence_parallel_group
```
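The aliasing pattern in the diff can be sketched standalone, without `torch.distributed`. The rank and world-size constants below are hypothetical stand-ins for the real `dist.get_rank(...)` / `dist.get_world_size(...)` calls; the point is that the module-level name binding makes the model-parallel API and the sequence-parallel API literally the same functions:

```python
# Minimal sketch (hypothetical values, not DeepSpeed's real implementation)
# of aliasing the model-parallel accessors to the sequence-parallel ones.

_SP_RANK = 1        # stand-in for dist.get_rank(group=sp_group)
_SP_WORLD_SIZE = 4  # stand-in for dist.get_world_size(group=sp_group)


def get_sequence_parallel_rank():
    """Return my rank within the sequence parallel group."""
    return _SP_RANK


def get_sequence_parallel_world_size():
    """Return the size of the sequence parallel group."""
    return _SP_WORLD_SIZE


# since SP is the only extra dimension over DP, MP can simply alias SP:
get_model_parallel_rank = get_sequence_parallel_rank
get_model_parallel_world_size = get_sequence_parallel_world_size

# a caller that only knows the model-parallel API now works unchanged:
print(get_model_parallel_rank())        # prints 1
print(get_model_parallel_world_size())  # prints 4
```

Because the aliases are plain name bindings rather than wrappers, `get_model_parallel_rank is get_sequence_parallel_rank` holds, so there is no extra call overhead and no behavior to keep in sync.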
