File tree: 1 file changed (+6, −3 lines)

@@ -558,7 +558,8 @@ def test_separate_all_gather_group():
 
     # --- Path 2: explicit ProcessGroupCollection ---
     Utils.initialize_model_parallel(context_parallel_size=world_size)
-    dp_cp_ranks = ps.get_data_parallel_global_ranks(with_context_parallel=True)
+    dp_cp_group = ps.get_data_parallel_group(with_context_parallel=True)
+    dp_cp_ranks = torch.distributed.get_process_group_ranks(dp_cp_group)
     dp_cp_ag_group = torch.distributed.new_group(ranks=dp_cp_ranks, backend='nccl')
 
     pg_collection = ProcessGroupCollection.use_mpu_process_groups()
@@ -584,8 +585,10 @@ def test_expert_all_gather_group():
     )
 
     # Get ranks for both regular and expert AG groups
-    dp_cp_ranks = ps.get_data_parallel_global_ranks(with_context_parallel=True)
-    expt_dp_ranks = ps.get_expert_data_parallel_global_ranks()
+    dp_cp_group = ps.get_data_parallel_group(with_context_parallel=True)
+    dp_cp_ranks = torch.distributed.get_process_group_ranks(dp_cp_group)
+    expt_dp_group = ps.get_expert_data_parallel_group()
+    expt_dp_ranks = torch.distributed.get_process_group_ranks(expt_dp_group)
 
     # Create AG groups for both regular and expert parameters
     dp_cp_ag_group = torch.distributed.new_group(ranks=dp_cp_ranks, backend='nccl')
You can’t perform that action at this time.
0 commit comments