Skip to content

Commit aee9add

Browse files
committed
test: fix tests not working with high nranks and change GA to test that
1 parent da326a0 commit aee9add

File tree

5 files changed

+56
-18
lines changed

5 files changed

+56
-18
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
os: [ubuntu-latest, macos-latest]
1717
python-version: ['3.10', '3.11', '3.12', '3.13']
1818
mpi: ['mpich', 'openmpi', 'intelmpi']
19-
rank: ['2', '3', '4']
19+
rank: ['2', '4', '9']
2020
exclude:
2121
- os: macos-latest
2222
mpi: 'intelmpi'

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
PIP := $(shell command -v pip3 2> /dev/null || command which pip 2> /dev/null)
22
PYTHON := $(shell command -v python3 2> /dev/null || command which python 2> /dev/null)
3-
NUM_PROCESSES = 3
3+
NUM_PROCESSES = 4
44

55
.PHONY: install dev-install dev-install_nccl install_ \
66
conda install_conda_nccl dev-install_conda dev-install_conda_nccl \

tests/test_distributedarray.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,21 @@ def test_distributed_norm(par):
199199
def test_distributed_masked(par):
200200
"""Test Asarray with masked array"""
201201
# Number of subcommunicators
202-
if MPI.COMM_WORLD.Get_size() % 2 == 0:
202+
size = MPI.COMM_WORLD.Get_size()
203+
204+
# Exclude not handled cases
205+
shape_axis = par['x'].shape[par['axis']]
206+
print('shape_axis, size', shape_axis, size, shape_axis % size != 0)
207+
if shape_axis % size != 0:
208+
pytest.skip(f"Array dimension to distributed ({shape_axis}) is not "
209+
f"divisible by the number of processes ({size})...")
210+
if size % 2 == 0:
203211
nsub = 2
204-
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
212+
elif size % 3 == 0:
205213
nsub = 3
206214
else:
207-
pass
215+
pytest.skip(f"Number of processes ({size}) is not divisible "
216+
"by 2 or 3...")
208217
subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
209218
mask = np.repeat(np.arange(nsub), subsize)
210219

@@ -236,12 +245,21 @@ def test_distributed_masked(par):
236245
def test_distributed_maskeddot(par1, par2):
237246
"""Test Distributed Dot product with masked array"""
238247
# Number of subcommunicators
239-
if MPI.COMM_WORLD.Get_size() % 2 == 0:
248+
size = MPI.COMM_WORLD.Get_size()
249+
250+
# Exclude not handled cases
251+
shape_axis = par1['x'].shape[par1['axis']]
252+
print('shape_axis, size', shape_axis, size, shape_axis % size != 0)
253+
if shape_axis % size != 0:
254+
pytest.skip(f"Array dimension to distributed ({shape_axis}) is not "
255+
f"divisible by the number of processes ({size})...")
256+
if size % 2 == 0:
240257
nsub = 2
241-
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
258+
elif size % 3 == 0:
242259
nsub = 3
243260
else:
244-
pass
261+
pytest.skip(f"Number of processes ({size}) is not divisible "
262+
"by 2 or 3...")
245263
subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
246264
mask = np.repeat(np.arange(nsub), subsize)
247265

@@ -271,12 +289,21 @@ def test_distributed_maskeddot(par1, par2):
271289
def test_distributed_maskednorm(par):
272290
"""Test Distributed numpy.linalg.norm method with masked array"""
273291
# Number of subcommunicators
274-
if MPI.COMM_WORLD.Get_size() % 2 == 0:
292+
size = MPI.COMM_WORLD.Get_size()
293+
294+
# Exclude not handled cases
295+
shape_axis = par['x'].shape[par['axis']]
296+
print('shape_axis, size', shape_axis, size, shape_axis % size != 0)
297+
if shape_axis % size != 0:
298+
pytest.skip(f"Array dimension to distributed ({shape_axis}) is not "
299+
f"divisible by the number of processes ({size})...")
300+
if size % 2 == 0:
275301
nsub = 2
276-
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
302+
elif size % 3 == 0:
277303
nsub = 3
278304
else:
279-
pass
305+
pytest.skip(f"Number of processes ({size}) is not divisible "
306+
"by 2 or 3...")
280307
subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
281308
mask = np.repeat(np.arange(nsub), subsize)
282309
# Replicate x as required in masked arrays

tests/test_fredholm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
size = MPI.COMM_WORLD.Get_size()
2121

2222
par1 = {
23-
"nsl": 12,
23+
"nsl": 21,
2424
"ny": 6,
2525
"nx": 4,
2626
"nz": 5,
@@ -30,7 +30,7 @@
3030
"dtype": "float32",
3131
} # real, saved Gt
3232
par2 = {
33-
"nsl": 12,
33+
"nsl": 21,
3434
"ny": 6,
3535
"nx": 4,
3636
"nz": 5,
@@ -40,7 +40,7 @@
4040
"dtype": "float32",
4141
} # real, unsaved Gt
4242
par3 = {
43-
"nsl": 12,
43+
"nsl": 21,
4444
"ny": 6,
4545
"nx": 4,
4646
"nz": 5,
@@ -50,7 +50,7 @@
5050
"dtype": "complex64",
5151
} # complex, saved Gt
5252
par4 = {
53-
"nsl": 12,
53+
"nsl": 21,
5454
"ny": 6,
5555
"nx": 4,
5656
"nz": 5,
@@ -60,7 +60,7 @@
6060
"dtype": "complex64",
6161
} # complex, unsaved Gt
6262
par5 = {
63-
"nsl": 12,
63+
"nsl": 21,
6464
"ny": 6,
6565
"nx": 4,
6666
"nz": 1,
@@ -70,7 +70,7 @@
7070
"dtype": "float32",
7171
} # real, saved Gt, nz=1
7272
par6 = {
73-
"nsl": 12,
73+
"nsl": 21,
7474
"ny": 6,
7575
"nx": 4,
7676
"nz": 1,

tests/test_solver.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
StackedDistributedArray
2727
)
2828

29-
np.random.seed(42)
3029
size = MPI.COMM_WORLD.Get_size()
3130
rank = MPI.COMM_WORLD.Get_rank()
3231

@@ -94,6 +93,8 @@
9493
)
9594
def test_cg(par):
9695
"""CG with MPIBlockDiag"""
96+
np.random.seed(42)
97+
9798
A = np.ones((par["ny"], par["nx"])) + par[
9899
"imag"] * np.ones((par["ny"], par["nx"]))
99100
Aop = MatrixMult(np.conj(A.T) @ A, dtype=par['dtype'])
@@ -139,6 +140,8 @@ def test_cg(par):
139140
)
140141
def test_cgls(par):
141142
"""CGLS with MPIBlockDiag"""
143+
np.random.seed(42)
144+
142145
A = np.ones((par["ny"], par["nx"])) + par[
143146
"imag"] * np.ones((par["ny"], par["nx"]))
144147
Aop = MatrixMult(np.conj(A.T) @ A + 1e-5 * np.eye(par["nx"], dtype=par['dtype']),
@@ -186,6 +189,8 @@ def test_cgls(par):
186189
)
187190
def test_cgls_broadcastdata(par):
188191
"""CGLS with broadcasted data vector"""
192+
np.random.seed(42)
193+
189194
A = (rank + 1) * np.ones((par["ny"], par["nx"])) + (rank + 2) * par[
190195
"imag"
191196
] * np.ones((par["ny"], par["nx"]))
@@ -232,6 +237,8 @@ def test_cgls_broadcastdata(par):
232237
)
233238
def test_cgls_broadcastmodel(par):
234239
"""CGLS with broadcasted model vector"""
240+
np.random.seed(42)
241+
235242
A = np.ones((par["ny"], par["nx"])) + par[
236243
"imag"] * np.ones((par["ny"], par["nx"]))
237244
Aop = MatrixMult(np.conj(A.T) @ A + 1e-5 * np.eye(par["nx"], dtype=par['dtype']),
@@ -281,6 +288,8 @@ def test_cgls_broadcastmodel(par):
281288
)
282289
def test_cg_stacked(par):
283290
"""CG with MPIStackedBlockDiag"""
291+
np.random.seed(42)
292+
284293
A = np.ones((par["ny"], par["nx"])) + par[
285294
"imag"] * np.ones((par["ny"], par["nx"]))
286295
Aop = MatrixMult(np.conj(A.T) @ A + 1e-5 * np.eye(par["nx"], dtype=par['dtype']),
@@ -344,6 +353,8 @@ def test_cg_stacked(par):
344353
)
345354
def test_cgls_stacked(par):
346355
"""CGLS with MPIStackedBlockDiag"""
356+
np.random.seed(42)
357+
347358
A = np.ones((par["ny"], par["nx"])) + par[
348359
"imag"] * np.ones((par["ny"], par["nx"]))
349360
Aop = MatrixMult(np.conj(A.T) @ A + 1e-5 * np.eye(par["nx"], dtype=par['dtype']),

0 commit comments

Comments
 (0)