@@ -21,7 +21,6 @@
 import pytest
 import torch
 from mpi4py import MPI
-from mpi4py.futures import MPIPoolExecutor
 from utils.util import skip_pre_blackwell
 
 import tensorrt_llm
@@ -40,6 +39,9 @@
     pickle.HIGHEST_PROTOCOL,
 )
 
+# needed since we reuse the MPI executor pool; the first test to run will leak a thread
+pytestmark = pytest.mark.threadleak(enabled=False)
+
 
 def fp8_quant(input, scale):
     finfo = torch.finfo(torch.float8_e4m3fn)
@@ -278,22 +280,23 @@ def ref_residual_rms_norm_out_quant_nvfp4(x, res):
                      marks=skip_pre_blackwell),
     ],
 )
-def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op):
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op,
+                                   mpi_pool_executor):
     torch.manual_seed(0)
     dtype = torch.bfloat16
-    tensor_parallel_size = 2
+    tensor_parallel_size = mpi_pool_executor.num_workers
     x = torch.randn((seq_len, hidden_size), dtype=dtype)
     residual = torch.randn_like(x)
     linear_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
-    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
-        results = executor.map(
-            run_single_rank,
-            *zip(*[(tensor_parallel_size, run_allreduce_op, x, residual,
-                    [linear_weight], hidden_size, dtype, fusion_op)] *
-                 tensor_parallel_size),
-        )
-        for r in results:
-            assert r is True
+    results = mpi_pool_executor.map(
+        run_single_rank,
+        *zip(*[(tensor_parallel_size, run_allreduce_op, x, residual,
+                [linear_weight], hidden_size, dtype, fusion_op)] *
+             tensor_parallel_size),
+    )
+    for r in results:
+        assert r is True
 
 
 @torch.inference_mode()
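
Note on the pattern above: `executor.map()` takes one iterable per parameter, so each test broadcasts a single argument tuple to every rank with `*zip(*[args] * tensor_parallel_size)`. A minimal standalone sketch of the idiom, where `f`, `args`, and `n` are hypothetical stand-ins rather than names from the test file:

# Standalone sketch of the argument-broadcasting idiom used in these tests.
# `f` and `args` stand in for run_single_rank and its argument tuple.
def f(a, b, c):
    return (a, b, c)

args = ("x", "y", "z")
n = 2  # tensor_parallel_size
# [args] * n duplicates the tuple once per rank; zip(*...) transposes it
# into one iterable per parameter, which is what executor.map() expects.
results = list(map(f, *zip(*[args] * n)))
assert results == [("x", "y", "z")] * n  # f is called once per rank
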
@@ -426,13 +429,14 @@ def run_moe_allreduce_op(token_input: torch.Tensor, residual: torch.Tensor,
 
 
 @torch.inference_mode()
-def test_moe_allreduce_patterns():
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+def test_moe_allreduce_patterns(mpi_pool_executor):
     torch.manual_seed(42)
 
     seq_len = 16
     hidden_size = 7168
     dtype = torch.bfloat16
-    tensor_parallel_size = 2
+    tensor_parallel_size = mpi_pool_executor.num_workers
     num_global_experts = 4
 
     # [num_token, 7168]
@@ -448,15 +452,14 @@ def test_moe_allreduce_patterns():
     residual = torch.randn_like(token_input)
 
     l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
-    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
-        results = executor.map(
-            run_moe_single_rank,
-            *zip(*[(tensor_parallel_size, run_moe_allreduce_op, token_input,
-                    residual, active_experts_token_input, scale, l0_weight)] *
-                 tensor_parallel_size),
-        )
-        for r in results:
-            assert r is True
+    results = mpi_pool_executor.map(
+        run_moe_single_rank,
+        *zip(*[(tensor_parallel_size, run_moe_allreduce_op, token_input,
+                residual, active_experts_token_input, scale, l0_weight)] *
+             tensor_parallel_size),
+    )
+    for r in results:
+        assert r is True
 
 
 def run_moe_finalize_single_rank(tensor_parallel_size, single_rank_forward_func,
@@ -544,13 +547,14 @@ def run_moe_finalize_allreduce_op(
 
 
 @torch.inference_mode()
-def test_moe_finalize_allreduce_patterns():
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+def test_moe_finalize_allreduce_patterns(mpi_pool_executor):
     torch.manual_seed(42)
 
     seq_len = 16
     hidden_size = 7168
     dtype = torch.bfloat16
-    tensor_parallel_size = 2
+    tensor_parallel_size = mpi_pool_executor.num_workers
     top_k = 8
 
     shared_expert_output = torch.randn((seq_len, hidden_size), dtype=dtype)
@@ -562,13 +566,11 @@ def test_moe_finalize_allreduce_patterns():
                                                  dtype=torch.int32)
     residual = torch.randn_like(shared_expert_output)
 
-    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
-        results = executor.map(
-            run_moe_finalize_single_rank,
-            *zip(*[(tensor_parallel_size, run_moe_finalize_allreduce_op,
-                    fc2_output, residual, shared_expert_output,
-                    expanded_idx_to_permuted_idx, scale)] *
-                 tensor_parallel_size),
-        )
-        for r in results:
-            assert r is True
+    results = mpi_pool_executor.map(
+        run_moe_finalize_single_rank,
+        *zip(*[(tensor_parallel_size, run_moe_finalize_allreduce_op, fc2_output,
+                residual, shared_expert_output, expanded_idx_to_permuted_idx,
+                scale)] * tensor_parallel_size),
+    )
+    for r in results:
+        assert r is True
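
For context, the tests above assume a shared `mpi_pool_executor` fixture that receives its worker count through `indirect=True` parametrization and is reused across tests (hence the threadleak marker added at the top of the file). A minimal sketch of what such a fixture could look like, assuming a session-scoped pool built on the same `mpi4py.futures.MPIPoolExecutor` this diff removes from the test module; this is illustrative, not the repository's actual conftest:

# Hypothetical conftest.py sketch; not the repository's actual fixture.
import pytest
from mpi4py.futures import MPIPoolExecutor


@pytest.fixture(scope="session")
def mpi_pool_executor(request):
    # request.param carries the worker count supplied via
    # @pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True).
    num_workers = request.param
    with MPIPoolExecutor(max_workers=num_workers) as executor:
        # Expose the worker count so tests can derive tensor_parallel_size.
        executor.num_workers = num_workers
        yield executor
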