
Commit 94dc97a

[feat][test] reuse MPI pool executor across tests (NVIDIA#5566)
Signed-off-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
1 parent 6000380 · commit 94dc97a

4 files changed, +130 -110 lines changed
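
The commit replaces per-test MPIPoolExecutor construction with a shared mpi_pool_executor fixture that tests request through indirect parametrization, passing the desired worker count as the parameter. The fixture definition itself lives in one of the changed files outside this excerpt; the following is only a minimal sketch of what it might look like, inferred from how the rewritten tests use it (the session scope and everything else here is an assumption):

import pytest
from mpi4py.futures import MPIPoolExecutor


# Hypothetical sketch (e.g. a conftest.py fixture); the real definition
# ships in this commit but is not shown below.
@pytest.fixture(scope="session")
def mpi_pool_executor(request):
    # indirect=True routes the parametrized worker count to request.param.
    with MPIPoolExecutor(max_workers=request.param) as executor:
        # Tests read executor.num_workers and submit work via executor.map.
        yield executor

Reusing one pool across tests is also why both files disable the thread-leak check: the first test to touch the shared pool appears to leak its worker thread.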

tests/unittest/_torch/multi_gpu/test_allreduce.py

Lines changed: 37 additions & 35 deletions
@@ -21,7 +21,6 @@
 import pytest
 import torch
 from mpi4py import MPI
-from mpi4py.futures import MPIPoolExecutor
 from utils.util import skip_pre_blackwell
 
 import tensorrt_llm
@@ -40,6 +39,9 @@
     pickle.HIGHEST_PROTOCOL,
 )
 
+# needed since we reuse the mpi executor pool, first test running will leak a thread
+pytestmark = pytest.mark.threadleak(enabled=False)
+
 
 def fp8_quant(input, scale):
     finfo = torch.finfo(torch.float8_e4m3fn)
@@ -278,22 +280,23 @@ def ref_residual_rms_norm_out_quant_nvfp4(x, res):
             marks=skip_pre_blackwell),
     ],
 )
-def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op):
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op,
+                                   mpi_pool_executor):
     torch.manual_seed(0)
     dtype = torch.bfloat16
-    tensor_parallel_size = 2
+    tensor_parallel_size = mpi_pool_executor.num_workers
     x = torch.randn((seq_len, hidden_size), dtype=dtype)
     residual = torch.randn_like(x)
     linear_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
-    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
-        results = executor.map(
-            run_single_rank,
-            *zip(*[(tensor_parallel_size, run_allreduce_op, x, residual,
-                    [linear_weight], hidden_size, dtype, fusion_op)] *
-                 tensor_parallel_size),
-        )
-        for r in results:
-            assert r is True
+    results = mpi_pool_executor.map(
+        run_single_rank,
+        *zip(*[(tensor_parallel_size, run_allreduce_op, x, residual,
+                [linear_weight], hidden_size, dtype, fusion_op)] *
+             tensor_parallel_size),
+    )
+    for r in results:
+        assert r is True
 
 
 @torch.inference_mode()
@@ -426,13 +429,14 @@ def run_moe_allreduce_op(token_input: torch.Tensor, residual: torch.Tensor,
 
 
 @torch.inference_mode()
-def test_moe_allreduce_patterns():
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+def test_moe_allreduce_patterns(mpi_pool_executor):
     torch.manual_seed(42)
 
     seq_len = 16
     hidden_size = 7168
     dtype = torch.bfloat16
-    tensor_parallel_size = 2
+    tensor_parallel_size = mpi_pool_executor.num_workers
     num_global_experts = 4
 
     # [num_token, 7168]
@@ -448,15 +452,14 @@ def test_moe_allreduce_patterns():
     residual = torch.randn_like(token_input)
 
     l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
-    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
-        results = executor.map(
-            run_moe_single_rank,
-            *zip(*[(tensor_parallel_size, run_moe_allreduce_op, token_input,
-                    residual, active_experts_token_input, scale, l0_weight)] *
-                 tensor_parallel_size),
-        )
-        for r in results:
-            assert r is True
+    results = mpi_pool_executor.map(
+        run_moe_single_rank,
+        *zip(*[(tensor_parallel_size, run_moe_allreduce_op, token_input,
+                residual, active_experts_token_input, scale, l0_weight)] *
+             tensor_parallel_size),
+    )
+    for r in results:
+        assert r is True
 
 
 def run_moe_finalize_single_rank(tensor_parallel_size, single_rank_forward_func,
@@ -544,13 +547,14 @@ def run_moe_finalize_allreduce_op(
 
 
 @torch.inference_mode()
-def test_moe_finalize_allreduce_patterns():
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+def test_moe_finalize_allreduce_patterns(mpi_pool_executor):
     torch.manual_seed(42)
 
     seq_len = 16
     hidden_size = 7168
     dtype = torch.bfloat16
-    tensor_parallel_size = 2
+    tensor_parallel_size = mpi_pool_executor.num_workers
     top_k = 8
 
     shared_expert_output = torch.randn((seq_len, hidden_size), dtype=dtype)
@@ -562,13 +566,11 @@ def test_moe_finalize_allreduce_patterns():
                                                 dtype=torch.int32)
     residual = torch.randn_like(shared_expert_output)
 
-    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
-        results = executor.map(
-            run_moe_finalize_single_rank,
-            *zip(*[(tensor_parallel_size, run_moe_finalize_allreduce_op,
-                    fc2_output, residual, shared_expert_output,
-                    expanded_idx_to_permuted_idx, scale)] *
-                 tensor_parallel_size),
-        )
-        for r in results:
-            assert r is True
+    results = mpi_pool_executor.map(
+        run_moe_finalize_single_rank,
+        *zip(*[(tensor_parallel_size, run_moe_finalize_allreduce_op, fc2_output,
+                residual, shared_expert_output, expanded_idx_to_permuted_idx,
+                scale)] * tensor_parallel_size),
+    )
+    for r in results:
+        assert r is True
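
Every rewritten test fans a single argument tuple out to all ranks with *zip(*[args] * tp_size): executor.map expects one iterable per function parameter, so transposing tp_size copies of the tuple makes each worker receive the identical call. A small, executor-agnostic demonstration of the idiom (ThreadPoolExecutor and work are illustrative stand-ins, not part of the commit):

from concurrent.futures import ThreadPoolExecutor


def work(tensor_parallel_size, value):
    return value * tensor_parallel_size


tp_size = 2
args = (tp_size, 10)
# [args] * tp_size -> [(2, 10), (2, 10)]
# zip(*...)        -> (2, 2), (10, 10): one iterable per parameter of work()
with ThreadPoolExecutor(max_workers=tp_size) as executor:
    results = executor.map(work, *zip(*[args] * tp_size))
print(list(results))  # [20, 20]: work(2, 10) was called tp_size times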

tests/unittest/_torch/multi_gpu/test_lowprecision_allreduce.py

Lines changed: 15 additions & 11 deletions
@@ -19,7 +19,6 @@
 import pytest
 import torch
 from mpi4py import MPI
-from mpi4py.futures import MPIPoolExecutor
 
 from tensorrt_llm._torch.distributed import AllReduceStrategy
 
@@ -30,6 +29,9 @@
     pickle.HIGHEST_PROTOCOL,
 )
 
+# needed since we reuse the mpi executor pool, first test running will leak a thread
+pytestmark = pytest.mark.threadleak(enabled=False)
+
 
 def run_single_rank(dtype, strategy, message_size):
     import numpy as np
@@ -243,18 +245,20 @@ def test(self, mode="acc"):
     [1024 * 1024 * x for x in [2, 4, 16, 32, 64, 132, 144]] + [64 * 70000],
     ids=lambda x: f"size{x}")
 @pytest.mark.parametrize(
-    "tp_size",
+    "mpi_pool_executor",
     [2, 4],  # 8
-    ids=["tp_size_2", "tp_size_4"])  # "tp_size_8"
-def test_lowprecision_allreduce_acc(dtype, strategy, message_size, tp_size):
+    ids=["tp_size_2", "tp_size_4"],
+    indirect=True)  # "tp_size_8"
+def test_lowprecision_allreduce_acc(dtype, strategy, message_size,
+                                    mpi_pool_executor):
     """
     Only test for accuracy. For performance testing,
     manually call TestLowPrecisionAllreduce(...).test('perf')
     """
-    with MPIPoolExecutor(max_workers=tp_size) as executor:
-        results = executor.map(
-            run_single_rank,
-            *zip(*[(dtype, strategy, message_size)] * tp_size),
-        )
-        for r in results:
-            assert r is True
+    tp_size = mpi_pool_executor.num_workers
+    results = mpi_pool_executor.map(
+        run_single_rank,
+        *zip(*[(dtype, strategy, message_size)] * tp_size),
+    )
+    for r in results:
+        assert r is True
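
Both files also illustrate indirect parametrization: with indirect=True, the values in @pytest.mark.parametrize are no longer handed to the test verbatim but are routed into the like-named fixture as request.param, and the test receives whatever the fixture yields. A self-contained illustration with hypothetical names:

import pytest


@pytest.fixture
def pool(request):
    # indirect=True delivers the parametrize value here, so the fixture
    # can build (or look up) a resource of the requested size.
    return {"num_workers": request.param}


@pytest.mark.parametrize("pool", [2, 4], indirect=True)
def test_pool_size(pool):
    # The test sees the fixture's return value, not the raw parameter.
    assert pool["num_workers"] in (2, 4)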
