Skip to content

Commit 40df110

Browse files
committed
test: started adding tests for nccl utility routines
1 parent da326a0 commit 40df110

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

tests_nccl/test_ncclutils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Test basic NCCL functionalities in _nccl
2+
Designed to run with n processes
3+
$ mpiexec -n 10 pytest test_nccl.py --with-mpi
4+
"""
5+
from mpi4py import MPI
6+
import numpy as np
7+
import cupy as cp
8+
from numpy.testing import assert_allclose
9+
import pytest
10+
11+
from pylops_mpi import DistributedArray, Partition
12+
from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather
13+
14+
np.random.seed(42)
15+
16+
nccl_comm = initialize_nccl_comm()
17+
18+
par1 = {'n': 3, 'dtype': np.float64}
19+
20+
21+
@pytest.mark.mpi(min_size=2)
22+
@pytest.mark.parametrize("par", [(par1), ])
23+
def test_allgather_samesize(par):
24+
"""Test nccl_allgather with arrays of same size"""
25+
size = MPI.COMM_WORLD.Get_size()
26+
rank = MPI.COMM_WORLD.Get_rank()
27+
28+
# Local array
29+
local_array = rank * cp.ones(par['n'], dtype=par['dtype'])
30+
31+
# Gathered array
32+
gathered_array = nccl_allgather(nccl_comm, local_array)
33+
34+
# Compare with global array created in rank0
35+
if rank == 0:
36+
global_array = np.ones(par['n'] * size, dtype=par['dtype'])
37+
for irank in range(size):
38+
global_array[irank * par["n"]: (irank + 1) * par["n"]] = irank
39+
40+
assert_allclose(
41+
gathered_array.get(),
42+
global_array,
43+
rtol=1e-14,
44+
)
45+
46+
47+
@pytest.mark.mpi(min_size=2)
48+
@pytest.mark.parametrize("par", [(par1), ])
49+
def test_allgather_samesize_withrecbuf(par):
50+
"""Test nccl_allgather with arrays of same size and rec_buf"""
51+
size = MPI.COMM_WORLD.Get_size()
52+
rank = MPI.COMM_WORLD.Get_rank()
53+
54+
# Local array
55+
local_array = rank * cp.ones(par['n'], dtype=par['dtype'])
56+
57+
# Gathered array
58+
gathered_array = cp.zeros(par['n'] * size, dtype=par['dtype'])
59+
gathered_array = nccl_allgather(nccl_comm, local_array, recv_buf=gathered_array)
60+
61+
# Compare with global array created in rank0
62+
if rank == 0:
63+
global_array = np.ones(par['n'] * size, dtype=par['dtype'])
64+
for irank in range(size):
65+
global_array[irank * par["n"]: (irank + 1) * par["n"]] = irank
66+
67+
assert_allclose(
68+
gathered_array.get(),
69+
global_array,
70+
rtol=1e-14,
71+
)
72+
73+
74+
# @pytest.mark.mpi(min_size=2)
75+
# @pytest.mark.parametrize("par", [(par1), ])
76+
# def test_allgather_differentsize_withrecbuf(par):
77+
# """Test nccl_allgather with arrays of different size and rec_buf"""
78+
# size = MPI.COMM_WORLD.Get_size()
79+
# rank = MPI.COMM_WORLD.Get_rank()
80+
81+
# # Local array
82+
# n = par['n'] # + (1 if rank == size - 1 else 0)
83+
# print(f'rank {rank}, n {n}')
84+
# local_array = rank * cp.ones(n, dtype=par['dtype'])
85+
86+
# # Gathered array
87+
# #gathered_array = cp.zeros(par['n'] * size + 1, dtype=par['dtype'])
88+
# gathered_array = cp.zeros(par['n'] * size, dtype=par['dtype'])
89+
# nccl_allgather(nccl_comm, local_array, recv_buf=gathered_array)
90+
91+
# # Compare with global array created in rank0
92+
# # if rank == 0:
93+
# # global_array = np.ones(par['n'] * size + 1, dtype=par['dtype'])
94+
# # for irank in range(size - 1):
95+
# # global_array[irank * par["n"]: (irank + 1) * par["n"]] = irank
96+
# # global_array[(size - 1) * par["n"]:] = size - 1
97+
98+
# # assert_allclose(
99+
# # gathered_array.get(),
100+
# # global_array,
101+
# # rtol=1e-14,
102+
# # )

0 commit comments

Comments
 (0)