Skip to content

Commit 3df8ef5

Browse files
committed
Fix test: use get_process_group_ranks instead of non-existent helper
Signed-off-by: jeffnvidia <jmahou@nvidia.com>
1 parent 97d9da8 commit 3df8ef5

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/unit_tests/test_parallel_state.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,8 @@ def test_separate_all_gather_group():
558558

559559
# --- Path 2: explicit ProcessGroupCollection ---
560560
Utils.initialize_model_parallel(context_parallel_size=world_size)
561-
dp_cp_ranks = ps.get_data_parallel_global_ranks(with_context_parallel=True)
561+
dp_cp_group = ps.get_data_parallel_group(with_context_parallel=True)
562+
dp_cp_ranks = torch.distributed.get_process_group_ranks(dp_cp_group)
562563
dp_cp_ag_group = torch.distributed.new_group(ranks=dp_cp_ranks, backend='nccl')
563564

564565
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
@@ -584,8 +585,10 @@ def test_expert_all_gather_group():
584585
)
585586

586587
# Get ranks for both regular and expert AG groups
587-
dp_cp_ranks = ps.get_data_parallel_global_ranks(with_context_parallel=True)
588-
expt_dp_ranks = ps.get_expert_data_parallel_global_ranks()
588+
dp_cp_group = ps.get_data_parallel_group(with_context_parallel=True)
589+
dp_cp_ranks = torch.distributed.get_process_group_ranks(dp_cp_group)
590+
expt_dp_group = ps.get_expert_data_parallel_group()
591+
expt_dp_ranks = torch.distributed.get_process_group_ranks(expt_dp_group)
589592

590593
# Create AG groups for both regular and expert parameters
591594
dp_cp_ag_group = torch.distributed.new_group(ranks=dp_cp_ranks, backend='nccl')

0 commit comments

Comments (0)