Skip to content

Commit ad7205d

Browse files
committed
complex dtype support for current ops (excl. fredholm)
1 parent 3526040 commit ad7205d

File tree

9 files changed

+225
-334
lines changed

9 files changed

+225
-334
lines changed

pylops_mpi/utils/_nccl.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
"int8": nccl.NCCL_INT8,
2626
"uint32": nccl.NCCL_UINT32,
2727
"uint64": nccl.NCCL_UINT64,
28+
# sending complex array as float with 2x size
29+
"complex64": nccl.NCCL_FLOAT32,
30+
"complex128": nccl.NCCL_FLOAT64,
2831
}
2932

3033

@@ -35,6 +38,13 @@ class NcclOp(IntEnum):
3538
MIN = nccl.NCCL_MIN
3639

3740

41+
def _nccl_buf_size(buf, count=None):
42+
if buf.dtype in ['complex64', 'complex128']:
43+
return 2 * count if count else 2 * buf.size
44+
else:
45+
return count if count else buf.size
46+
47+
3848
def mpi_op_to_nccl(mpi_op) -> NcclOp:
3949
""" Map MPI reduction operation to NCCL equivalent
4050
@@ -155,7 +165,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
155165
nccl_comm.allGather(
156166
send_buf.data.ptr,
157167
recv_buf.data.ptr,
158-
send_buf.size,
168+
_nccl_buf_size(send_buf),
159169
cupy_to_nccl_dtype[str(send_buf.dtype)],
160170
cp.cuda.Stream.null.ptr,
161171
)
@@ -193,7 +203,7 @@ def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) ->
193203
nccl_comm.allReduce(
194204
send_buf.data.ptr,
195205
recv_buf.data.ptr,
196-
send_buf.size,
206+
_nccl_buf_size(send_buf),
197207
cupy_to_nccl_dtype[str(send_buf.dtype)],
198208
mpi_op_to_nccl(op),
199209
cp.cuda.Stream.null.ptr,
@@ -220,7 +230,7 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
220230
local_array[index] = value
221231
nccl_comm.bcast(
222232
local_array[index].data.ptr,
223-
local_array[index].size,
233+
_nccl_buf_size(local_array[index]),
224234
cupy_to_nccl_dtype[str(local_array[index].dtype)],
225235
0,
226236
cp.cuda.Stream.null.ptr,
@@ -302,7 +312,7 @@ def nccl_send(nccl_comm, send_buf, dest, count):
302312
Number of elements to send from `send_buf`.
303313
"""
304314
nccl_comm.send(send_buf.data.ptr,
305-
count,
315+
_nccl_buf_size(send_buf, count),
306316
cupy_to_nccl_dtype[str(send_buf.dtype)],
307317
dest,
308318
cp.cuda.Stream.null.ptr
@@ -325,7 +335,7 @@ def nccl_recv(nccl_comm, recv_buf, source, count=None):
325335
Number of elements to receive.
326336
"""
327337
nccl_comm.recv(recv_buf.data.ptr,
328-
count,
338+
_nccl_buf_size(recv_buf, count),
329339
cupy_to_nccl_dtype[str(recv_buf.dtype)],
330340
source,
331341
cp.cuda.Stream.null.ptr

tests_nccl/test_blockdiag_nccl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
nccl_comm = initialize_nccl_comm()
1919

2020
par1 = {'ny': 101, 'nx': 101, 'dtype': np.float64}
21-
# par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}
21+
par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}
2222
par2 = {'ny': 301, 'nx': 101, 'dtype': np.float64}
23-
# par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}
23+
par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}
2424

2525
np.random.seed(42)
2626

2727

2828
@pytest.mark.mpi(min_size=2)
29-
@pytest.mark.parametrize("par", [(par1), (par2)])
29+
@pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
3030
def test_blockdiag_nccl(par):
3131
"""Test the MPIBlockDiag with NCCL"""
3232
size = MPI.COMM_WORLD.Get_size()
@@ -71,7 +71,7 @@ def test_blockdiag_nccl(par):
7171

7272

7373
@pytest.mark.mpi(min_size=2)
74-
@pytest.mark.parametrize("par", [(par1), (par2)])
74+
@pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
7575
def test_stacked_blockdiag_nccl(par):
7676
"""Tests for MPIStackedBlogDiag with NCCL"""
7777
size = MPI.COMM_WORLD.Get_size()

tests_nccl/test_derivative_nccl.py

Lines changed: 54 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@
3838
"partition": pylops_mpi.Partition.BROADCAST,
3939
}
4040

41-
# par1j = {
42-
# "nz": 600,
43-
# "dz": 1.0,
44-
# "edge": False,
45-
# "dtype": np.complex128,
46-
# "partition": pylops_mpi.Partition.SCATTER
47-
# }
41+
par1j = {
42+
"nz": 600,
43+
"dz": 1.0,
44+
"edge": False,
45+
"dtype": np.complex128,
46+
"partition": pylops_mpi.Partition.SCATTER
47+
}
4848

4949
par1e = {
5050
"nz": 600,
@@ -70,13 +70,13 @@
7070
"partition": pylops_mpi.Partition.BROADCAST,
7171
}
7272

73-
# par2j = {
74-
# "nz": (100, 151),
75-
# "dz": 1.0,
76-
# "edge": False,
77-
# "dtype": np.complex128,
78-
# "partition": pylops_mpi.Partition.SCATTER
79-
# }
73+
par2j = {
74+
"nz": (100, 151),
75+
"dz": 1.0,
76+
"edge": False,
77+
"dtype": np.complex128,
78+
"partition": pylops_mpi.Partition.SCATTER
79+
}
8080

8181
par2e = {
8282
"nz": (100, 151),
@@ -102,13 +102,13 @@
102102
"partition": pylops_mpi.Partition.BROADCAST,
103103
}
104104

105-
# par3j = {
106-
# "nz": (101, 51, 100),
107-
# "dz": 0.4,
108-
# "edge": True,
109-
# "dtype": np.complex128,
110-
# "partition": pylops_mpi.Partition.SCATTER
111-
# }
105+
par3j = {
106+
"nz": (101, 51, 100),
107+
"dz": 0.4,
108+
"edge": True,
109+
"dtype": np.complex128,
110+
"partition": pylops_mpi.Partition.SCATTER
111+
}
112112

113113
par3e = {
114114
"nz": (101, 51, 100),
@@ -134,13 +134,13 @@
134134
"partition": pylops_mpi.Partition.BROADCAST,
135135
}
136136

137-
# par4j = {
138-
# "nz": (79, 101, 50),
139-
# "dz": 0.4,
140-
# "edge": True,
141-
# "dtype": np.complex128,
142-
# "partition": pylops_mpi.Partition.SCATTER
143-
# }
137+
par4j = {
138+
"nz": (79, 101, 50),
139+
"dz": 0.4,
140+
"edge": True,
141+
"dtype": np.complex128,
142+
"partition": pylops_mpi.Partition.SCATTER
143+
}
144144

145145
par4e = {
146146
"nz": (79, 101, 50),
@@ -188,24 +188,10 @@
188188

189189

190190
@pytest.mark.mpi(min_size=2)
191-
@pytest.mark.parametrize(
192-
"par",
193-
[
194-
(par1),
195-
(par1b),
196-
(par1e),
197-
(par2),
198-
(par2b),
199-
(par2e),
200-
(par3),
201-
(par3b),
202-
(par3e),
203-
(par4),
204-
(par4b),
205-
(par4e),
206-
],
207-
)
208-
def test_first_derivative_forward(par):
191+
@pytest.mark.parametrize("par", [(par1), (par1b), (par1j), (par1e), (par2), (par2b),
192+
(par2j), (par2e), (par3), (par3b), (par3j), (par3e),
193+
(par4), (par4b), (par4j), (par4e)])
194+
def test_first_derivative_forward_nccl(par):
209195
"""MPIFirstDerivative operator (forward stencil)"""
210196
Fop_MPI = pylops_mpi.MPIFirstDerivative(
211197
dims=par["nz"],
@@ -250,24 +236,10 @@ def test_first_derivative_forward(par):
250236

251237

252238
@pytest.mark.mpi(min_size=2)
253-
@pytest.mark.parametrize(
254-
"par",
255-
[
256-
(par1),
257-
(par1b),
258-
(par1e),
259-
(par2),
260-
(par2b),
261-
(par2e),
262-
(par3),
263-
(par3b),
264-
(par3e),
265-
(par4),
266-
(par4b),
267-
(par4e),
268-
],
269-
)
270-
def test_first_derivative_backward(par):
239+
@pytest.mark.parametrize("par", [(par1), (par1b), (par1j), (par1e), (par2), (par2b),
240+
(par2j), (par2e), (par3), (par3b), (par3j), (par3e),
241+
(par4), (par4b), (par4j), (par4e)])
242+
def test_first_derivative_backward_nccl(par):
271243
"""MPIFirstDerivative operator (backward stencil)"""
272244
Fop_MPI = pylops_mpi.MPIFirstDerivative(
273245
dims=par["nz"],
@@ -311,24 +283,10 @@ def test_first_derivative_backward(par):
311283

312284

313285
@pytest.mark.mpi(min_size=2)
314-
@pytest.mark.parametrize(
315-
"par",
316-
[
317-
(par1),
318-
(par1b),
319-
(par1e),
320-
(par2),
321-
(par2b),
322-
(par2e),
323-
(par3),
324-
(par3b),
325-
(par3e),
326-
(par4),
327-
(par4b),
328-
(par4e),
329-
],
330-
)
331-
def test_first_derivative_centered(par):
286+
@pytest.mark.parametrize("par", [(par1), (par1b), (par1j), (par1e), (par2), (par2b),
287+
(par2j), (par2e), (par3), (par3b), (par3j), (par3e),
288+
(par4), (par4b), (par4j), (par4e)])
289+
def test_first_derivative_centered_nccl(par):
332290
"""MPIFirstDerivative operator (centered stencil)"""
333291
for order in [3, 5]:
334292
Fop_MPI = pylops_mpi.MPIFirstDerivative(
@@ -375,24 +333,10 @@ def test_first_derivative_centered(par):
375333

376334

377335
@pytest.mark.mpi(min_size=2)
378-
@pytest.mark.parametrize(
379-
"par",
380-
[
381-
(par1),
382-
(par1b),
383-
(par1e),
384-
(par2),
385-
(par2b),
386-
(par2e),
387-
(par3),
388-
(par3b),
389-
(par3e),
390-
(par4),
391-
(par4b),
392-
(par4e),
393-
],
394-
)
395-
def test_second_derivative_forward(par):
336+
@pytest.mark.parametrize("par", [(par1), (par1b), (par1j), (par1e), (par2), (par2b),
337+
(par2j), (par2e), (par3), (par3b), (par3j), (par3e),
338+
(par4), (par4b), (par4j), (par4e)])
339+
def test_second_derivative_forward_nccl(par):
396340
"""MPISecondDerivative operator (forward stencil)"""
397341
Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative(
398342
dims=par["nz"],
@@ -436,24 +380,10 @@ def test_second_derivative_forward(par):
436380

437381

438382
@pytest.mark.mpi(min_size=2)
439-
@pytest.mark.parametrize(
440-
"par",
441-
[
442-
(par1),
443-
(par1b),
444-
(par1e),
445-
(par2),
446-
(par2b),
447-
(par2e),
448-
(par3),
449-
(par3b),
450-
(par3e),
451-
(par4),
452-
(par4b),
453-
(par4e),
454-
],
455-
)
456-
def test_second_derivative_backward(par):
383+
@pytest.mark.parametrize("par", [(par1), (par1b), (par1j), (par1e), (par2), (par2b),
384+
(par2j), (par2e), (par3), (par3b), (par3j), (par3e),
385+
(par4), (par4b), (par4j), (par4e)])
386+
def test_second_derivative_backward_nccl(par):
457387
"""MPISecondDerivative operator (backward stencil)"""
458388
Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative(
459389
dims=par["nz"],
@@ -497,24 +427,10 @@ def test_second_derivative_backward(par):
497427

498428

499429
@pytest.mark.mpi(min_size=2)
500-
@pytest.mark.parametrize(
501-
"par",
502-
[
503-
(par1),
504-
(par1b),
505-
(par1e),
506-
(par2),
507-
(par2b),
508-
(par2e),
509-
(par3),
510-
(par3b),
511-
(par3e),
512-
(par4),
513-
(par4b),
514-
(par4e),
515-
],
516-
)
517-
def test_second_derivative_centered(par):
430+
@pytest.mark.parametrize("par", [(par1), (par1b), (par1j), (par1e), (par2), (par2b),
431+
(par2j), (par2e), (par3), (par3b), (par3j), (par3e),
432+
(par4), (par4b), (par4j), (par4e)])
433+
def test_second_derivative_centered_nccl(par):
518434
"""MPISecondDerivative operator (centered stencil)"""
519435
Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative(
520436
dims=par["nz"],
@@ -559,7 +475,7 @@ def test_second_derivative_centered(par):
559475

560476
@pytest.mark.mpi(min_size=2)
561477
@pytest.mark.parametrize("par", [(par5), (par5e), (par6), (par6e)])
562-
def test_laplacian(par):
478+
def test_laplacian_nccl(par):
563479
"""MPILaplacian Operator"""
564480
for kind in ["forward", "backward", "centered"]:
565481
Lop_MPI = pylops_mpi.basicoperators.MPILaplacian(
@@ -607,7 +523,7 @@ def test_laplacian(par):
607523

608524
@pytest.mark.mpi(min_size=2)
609525
@pytest.mark.parametrize("par", [(par5), (par5e), (par6), (par6e)])
610-
def test_gradient(par):
526+
def test_gradient_nccl(par):
611527
"""MPIGradient Operator"""
612528
for kind in ["forward", "backward", "centered"]:
613529
Gop_MPI = pylops_mpi.basicoperators.MPIGradient(

0 commit comments

Comments
 (0)