1+ """Test basic NCCL functionalities in _nccl
2+ Designed to run with n processes
3+ $ mpiexec -n 10 pytest test_nccl.py --with-mpi
4+ """
5+ from mpi4py import MPI
6+ import numpy as np
7+ import cupy as cp
8+ from numpy .testing import assert_allclose
9+ import pytest
10+
11+ from pylops_mpi import DistributedArray , Partition
12+ from pylops_mpi .utils ._nccl import initialize_nccl_comm , nccl_allgather
13+
14+ np .random .seed (42 )
15+
16+ nccl_comm = initialize_nccl_comm ()
17+
18+ par1 = {'n' : 3 , 'dtype' : np .float64 }
19+
20+
21+ @pytest .mark .mpi (min_size = 2 )
22+ @pytest .mark .parametrize ("par" , [(par1 ), ])
23+ def test_allgather_samesize (par ):
24+ """Test nccl_allgather with arrays of same size"""
25+ size = MPI .COMM_WORLD .Get_size ()
26+ rank = MPI .COMM_WORLD .Get_rank ()
27+
28+ # Local array
29+ local_array = rank * cp .ones (par ['n' ], dtype = par ['dtype' ])
30+
31+ # Gathered array
32+ gathered_array = nccl_allgather (nccl_comm , local_array )
33+
34+ # Compare with global array created in rank0
35+ if rank == 0 :
36+ global_array = np .ones (par ['n' ] * size , dtype = par ['dtype' ])
37+ for irank in range (size ):
38+ global_array [irank * par ["n" ]: (irank + 1 ) * par ["n" ]] = irank
39+
40+ assert_allclose (
41+ gathered_array .get (),
42+ global_array ,
43+ rtol = 1e-14 ,
44+ )
45+
46+
47+ @pytest .mark .mpi (min_size = 2 )
48+ @pytest .mark .parametrize ("par" , [(par1 ), ])
49+ def test_allgather_samesize_withrecbuf (par ):
50+ """Test nccl_allgather with arrays of same size and rec_buf"""
51+ size = MPI .COMM_WORLD .Get_size ()
52+ rank = MPI .COMM_WORLD .Get_rank ()
53+
54+ # Local array
55+ local_array = rank * cp .ones (par ['n' ], dtype = par ['dtype' ])
56+
57+ # Gathered array
58+ gathered_array = cp .zeros (par ['n' ] * size , dtype = par ['dtype' ])
59+ gathered_array = nccl_allgather (nccl_comm , local_array , recv_buf = gathered_array )
60+
61+ # Compare with global array created in rank0
62+ if rank == 0 :
63+ global_array = np .ones (par ['n' ] * size , dtype = par ['dtype' ])
64+ for irank in range (size ):
65+ global_array [irank * par ["n" ]: (irank + 1 ) * par ["n" ]] = irank
66+
67+ assert_allclose (
68+ gathered_array .get (),
69+ global_array ,
70+ rtol = 1e-14 ,
71+ )
72+
73+
74+ # @pytest.mark.mpi(min_size=2)
75+ # @pytest.mark.parametrize("par", [(par1), ])
76+ # def test_allgather_differentsize_withrecbuf(par):
77+ # """Test nccl_allgather with arrays of different size and rec_buf"""
78+ # size = MPI.COMM_WORLD.Get_size()
79+ # rank = MPI.COMM_WORLD.Get_rank()
80+
81+ # # Local array
82+ # n = par['n'] # + (1 if rank == size - 1 else 0)
83+ # print(f'rank {rank}, n {n}')
84+ # local_array = rank * cp.ones(n, dtype=par['dtype'])
85+
86+ # # Gathered array
87+ # #gathered_array = cp.zeros(par['n'] * size + 1, dtype=par['dtype'])
88+ # gathered_array = cp.zeros(par['n'] * size, dtype=par['dtype'])
89+ # nccl_allgather(nccl_comm, local_array, recv_buf=gathered_array)
90+
91+ # # Compare with global array created in rank0
92+ # # if rank == 0:
93+ # # global_array = np.ones(par['n'] * size + 1, dtype=par['dtype'])
94+ # # for irank in range(size - 1):
95+ # # global_array[irank * par["n"]: (irank + 1) * par["n"]] = irank
96+ # # global_array[(size - 1) * par["n"]:] = size - 1
97+
98+ # # assert_allclose(
99+ # # gathered_array.get(),
100+ # # global_array,
101+ # # rtol=1e-14,
102+ # # )
0 commit comments