Commit 0e51e09

Add getter APIs for TP/PP/DP ranks in DeepSpeedEngine (#7427)
Thanks again for giving opportunity for improving this Community! This PR is from Issue #7423. 1) Motivation To improve compatibility with low-level profiling tools (e.g., NVIDIA CUPTI or DCGM), it can be useful to expose parallelism-specific rank (tensor/pipeline/data) at the engine level. 2) Changes I Added three getter methods to DeepSpeedEngine: - get_tensor_parallel_rank() - get_pipeline_parallel_rank() - get_data_parallel_rank() Thank you for reviewing this contribution! --------- Signed-off-by: WoosungMyung <dntjd517@naver.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
1 parent e1560d8 commit 0e51e09
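As a usage sketch, a profiling hook could tag samples with these ranks. Only the three getter names come from this commit; `FakeEngine`, the returned rank values, and `profiler_label` below are hypothetical stand-ins, since a real `DeepSpeedEngine` needs a full distributed setup:

```python
# Sketch: tag low-level profiler output (e.g. CUPTI/DCGM logs) with
# parallelism ranks via the new getters. FakeEngine stands in for an
# initialized DeepSpeedEngine; getter names match this commit, the
# hard-coded rank values are purely illustrative.
class FakeEngine:
    def get_tensor_parallel_rank(self):
        return 1

    def get_pipeline_parallel_rank(self):
        return 0

    def get_data_parallel_rank(self):
        return 3


def profiler_label(engine):
    # Build a label such as "tp1-pp0-dp3" to attach to per-process metrics.
    return (f"tp{engine.get_tensor_parallel_rank()}"
            f"-pp{engine.get_pipeline_parallel_rank()}"
            f"-dp{engine.get_data_parallel_rank()}")


print(profiler_label(FakeEngine()))  # prints "tp1-pp0-dp3"
```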

File tree

2 files changed: +12 −0 lines changed


deepspeed/runtime/engine.py (9 additions, 0 deletions)

```diff
@@ -730,6 +730,15 @@ def random_ltd_initialize(self):
             raise ValueError(f'not yet support')
         #self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)

+    def get_data_parallel_rank(self):
+        return groups.get_data_parallel_rank()
+
+    def get_tensor_parallel_rank(self):
+        return groups.get_tensor_model_parallel_rank()
+
+    def get_model_parallel_rank(self):
+        return groups.get_model_parallel_rank()
+
     def get_sequence_parallel_group(self):
         return self.seq_parallel_group
```
deepspeed/runtime/pipe/engine.py (3 additions, 0 deletions)

```diff
@@ -535,6 +535,9 @@ def is_last_stage(self):
         """True if this process is in the last stage in the pipeline."""
         return self.stage_id == self.num_stages - 1

+    def get_pipeline_parallel_rank(self):
+        return self.stage_id
+
     def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None):
         if reduce is None:
             return outputs
```
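The engine getters delegate to `deepspeed.utils.groups` (and, for the pipeline rank, to the engine's `stage_id`). As background, such ranks can be derived arithmetically from the global rank and the parallel degrees. The sketch below assumes a simple row-major layout (tensor rank varies fastest, then data, then pipeline); DeepSpeed's actual group construction may order the axes differently, so this is illustrative, not the library's implementation:

```python
def parallel_ranks(global_rank, tp_size, pp_size, dp_size):
    """Decompose a global rank into (tp, pp, dp) ranks.

    Assumes a row-major layout: tensor-parallel rank varies fastest,
    then data-parallel, then pipeline-parallel. Real frameworks may
    choose a different axis order.
    """
    tp_rank = global_rank % tp_size
    dp_rank = (global_rank // tp_size) % dp_size
    pp_rank = global_rank // (tp_size * dp_size)
    return tp_rank, pp_rank, dp_rank


# 8 processes with 2-way TP, 2-way PP, 2-way DP: every process gets a
# unique (tp, pp, dp) triple under this layout.
triples = [parallel_ranks(r, tp_size=2, pp_size=2, dp_size=2) for r in range(8)]
```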

0 commit comments