@@ -119,4 +119,46 @@ def test_stacked_vstack_nccl(par):
119119 assert_allclose (y_rmat_mpi .get (), y_rmat_np , rtol = 1e-14 )
120120
121121
122- # TODO: Test of HStack
122+ @pytest .mark .mpi (min_size = 2 )
123+ @pytest .mark .parametrize ("par" , [(par1 ), (par2 )])
124+ def test_hstack (par ):
125+ """Test the MPIHStack operator with NCCL"""
126+ size = MPI .COMM_WORLD .Get_size ()
127+ rank = MPI .COMM_WORLD .Get_rank ()
128+ A_gpu = cp .ones (shape = (par ['ny' ], par ['nx' ])) + par ['imag' ] * cp .ones (shape = (par ['ny' ], par ['nx' ]))
129+ Op = pylops .MatrixMult (A = ((rank + 1 ) * A_gpu ).astype (par ['dtype' ]))
130+ HStack_MPI = pylops_mpi .MPIHStack (ops = [Op , ], base_comm_nccl = nccl_comm )
131+
132+ # Scattered DistributedArray
133+ x = pylops_mpi .DistributedArray (global_shape = size * par ['nx' ],
134+ base_comm_nccl = nccl_comm ,
135+ partition = pylops_mpi .Partition .SCATTER ,
136+ dtype = par ['dtype' ],
137+ engine = "cupy" )
138+ x [:] = cp .ones (shape = par ['nx' ], dtype = par ['dtype' ])
139+ x_global = x .asarray ()
140+
141+ # Broadcasted DistributedArray(global_shape == local_shape)
142+ y = pylops_mpi .DistributedArray (global_shape = par ['ny' ],
143+ base_comm_nccl = nccl_comm ,
144+ partition = pylops_mpi .Partition .BROADCAST ,
145+ dtype = par ['dtype' ],
146+ engine = "cupy" )
147+ y [:] = cp .ones (shape = par ['ny' ], dtype = par ['dtype' ])
148+ y_global = y .asarray ()
149+
150+ x_mat = HStack_MPI @ x
151+ y_rmat = HStack_MPI .H @ y
152+ assert isinstance (x_mat , pylops_mpi .DistributedArray )
153+ assert isinstance (y_rmat , pylops_mpi .DistributedArray )
154+
155+ x_mat_mpi = x_mat .asarray ()
156+ y_rmat_mpi = y_rmat .asarray ()
157+
158+ if rank == 0 :
159+ ops = [pylops .MatrixMult (A = ((i + 1 ) * A_gpu .get ()).astype (par ['dtype' ])) for i in range (size )]
160+ HStack = pylops .HStack (ops = ops )
161+ x_mat_np = HStack @ x_global .get ()
162+ y_rmat_np = HStack .H @ y_global .get ()
163+ assert_allclose (x_mat_mpi .get (), x_mat_np , rtol = 1e-14 )
164+ assert_allclose (y_rmat_mpi .get (), y_rmat_np , rtol = 1e-14 )
0 commit comments