@@ -92,13 +92,50 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
     destroy_model_parallel()
 
 
-# 1. Tensor Parallel Test
-def _test_tensor_parallel_helper(config, rank, size):
-    initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
-    tp_group = get_tensor_model_parallel_group()
-    model = MegatronModel(tp_size=size, tp_group=tp_group).cuda()
+# Unified parallelism test helper
+def _test_parallelism_helper(
+    config,
+    rank,
+    size,
+    tensor_model_parallel_size=1,
+    context_parallel_size=1,
+    use_rank_in_seed=False,
+):
+    """
+    Unified helper for testing different parallelism configurations.
+    Args:
+        config: Quantization config to test
+        rank: Current rank in distributed setup
+        size: Total number of processes
+        tensor_model_parallel_size: Size of tensor model parallel group (default: 1)
+        context_parallel_size: Size of context parallel group (default: 1)
+        use_rank_in_seed: Whether to add rank to seed for different data across ranks (default: False)
+    """
+    seed = SEED + rank if use_rank_in_seed else SEED
+    initialize_for_megatron(
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        context_parallel_size=context_parallel_size,
+        seed=seed,
+    )
 
-    data_tensor_context_parallel_test_helper(model, config, tp_group=tp_group)
+    # Determine if we need tp_group and dp_group
+    tp_group = get_tensor_model_parallel_group() if tensor_model_parallel_size > 1 else None
+    dp_group = get_data_parallel_group(with_context_parallel=True)
+
+    # Create model with appropriate parallelism settings
+    model = MegatronModel(
+        tp_size=tensor_model_parallel_size,
+        cp_size=context_parallel_size,
+        tp_group=tp_group,
+    ).cuda()
+
+    # Call the test helper with appropriate groups
+    data_tensor_context_parallel_test_helper(
+        model,
+        config,
+        dp_group=dp_group,
+        tp_group=tp_group,
+    )
 
 
 @pytest.mark.parametrize(
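Side sketch of the binding pattern the refactored tests below rely on: functools.partial pre-binds `config` plus the parallelism keyword arguments, and the spawned worker then supplies only `(rank, size)`. A minimal, GPU-free sketch, assuming spawn_multiprocess_job invokes job(rank, size) on each worker; the helper name and the config dict here are placeholders, not the real test objects:

from functools import partial


def _parallelism_helper_sketch(
    config, rank, size, tensor_model_parallel_size=1, context_parallel_size=1, use_rank_in_seed=False
):
    # Placeholder body standing in for _test_parallelism_helper: just show which
    # arguments each worker ends up with.
    print(
        f"rank={rank}/{size} config={config} tp={tensor_model_parallel_size} "
        f"cp={context_parallel_size} rank_in_seed={use_rank_in_seed}"
    )


# partial() fixes config (positionally) and the parallelism kwargs up front...
job = partial(_parallelism_helper_sketch, {"algo": "fp8"}, tensor_model_parallel_size=2)

# ...so each spawned worker only needs to pass its own (rank, size), which land in
# the second and third positional slots of the helper.
job(0, 2)
job(1, 2)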
@@ -115,18 +152,12 @@ def _test_tensor_parallel_helper(config, rank, size):
 )
 def test_tensor_parallel(need_2_gpus, config):
     spawn_multiprocess_job(
-        size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl"
+        size=2,
+        job=partial(_test_parallelism_helper, config, tensor_model_parallel_size=2),
+        backend="nccl",
     )
 
 
-# 2. Data Parallel Test
-def _test_data_parallel_helper(config, rank, size):
-    initialize_for_megatron(seed=SEED + rank)  # modify seed so data is different across ranks
-    model = MegatronModel().cuda()
-
-    data_tensor_context_parallel_test_helper(model, config, dp_group=get_data_parallel_group())
-
-
 @pytest.mark.parametrize(
     "config",
     [
@@ -140,18 +171,10 @@ def _test_data_parallel_helper(config, rank, size):
     ],
 )
 def test_data_parallel(need_2_gpus, config):
-    spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl")
-
-
-# 3. Context Parallel Test
-def _test_context_parallel_helper(config, rank, size):
-    initialize_for_megatron(
-        context_parallel_size=size, seed=SEED + rank
-    )  # modify seed so data is different across ranks
-    model = MegatronModel(cp_size=size).cuda()
-
-    data_tensor_context_parallel_test_helper(
-        model, config, dp_group=get_data_parallel_group(with_context_parallel=True)
+    spawn_multiprocess_job(
+        size=2,
+        job=partial(_test_parallelism_helper, config, use_rank_in_seed=True),
+        backend="nccl",
     )
 
 
@@ -169,21 +192,11 @@ def _test_context_parallel_helper(config, rank, size):
 )
 def test_context_parallel(need_2_gpus, config):
     spawn_multiprocess_job(
-        size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
-    )
-
-
-# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
-def _test_data_tensor_context_parallel_helper(config, rank, size):
-    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank)
-    tp_group = get_tensor_model_parallel_group()
-    model = MegatronModel(tp_size=2, cp_size=2, tp_group=tp_group).cuda()
-
-    data_tensor_context_parallel_test_helper(
-        model,
-        config,
-        dp_group=get_data_parallel_group(with_context_parallel=True),
-        tp_group=tp_group,
+        size=2,
+        job=partial(
+            _test_parallelism_helper, config, context_parallel_size=2, use_rank_in_seed=True
+        ),
+        backend="nccl",
     )
 
 
@@ -201,7 +214,15 @@ def _test_data_tensor_context_parallel_helper(config, rank, size):
 )
 def test_data_tensor_context_parallel(need_8_gpus, config):
     spawn_multiprocess_job(
-        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
+        size=8,
+        job=partial(
+            _test_parallelism_helper,
+            config,
+            tensor_model_parallel_size=2,
+            context_parallel_size=2,
+            use_rank_in_seed=True,
+        ),
+        backend="nccl",
     )
 
 
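For the combined 8-GPU case above, the data-parallel size is never passed explicitly; with tensor and context parallel sizes of 2 each it falls out of the world size, as the comment removed from the old helper ("DP=2 + TP=2 + CP=2 on 2*2*2=8 GPUs") spelled out. A small arithmetic sketch of that relation, assuming pipeline parallelism stays at 1:

# World size decomposes as dp * tp * cp (pipeline parallelism assumed to be 1 here).
world_size, tp, cp = 8, 2, 2
dp = world_size // (tp * cp)
assert dp == 2  # i.e. DP=2 x TP=2 x CP=2 on 2*2*2 = 8 GPUs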