Skip to content

Commit 211435d

Browse files
committed
test: generalize tests to any choice of nranks
1 parent 9786dbd commit 211435d

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

tests/test_distributedarray.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,16 @@ def test_distributed_norm(par):
199199
(par8, par9), (par8b, par9b)])
200200
def test_distributed_maskeddot(par1, par2):
201201
"""Test Distributed Dot product with masked array"""
202-
nsub = 3 # number of subcommunicators
202+
# number of subcommunicators
203+
if MPI.COMM_WORLD.Get_size() % 2 == 0:
204+
nsub = 2
205+
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
206+
nsub = 3
207+
else:
208+
pass
203209
subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
204210
mask = np.repeat(np.arange(nsub), subsize)
211+
print('subsize, mask', subsize, mask)
205212
# Replicate x1 and x2 as required in masked arrays
206213
x1, x2 = par1['x'], par2['x']
207214
if par1['axis'] != 0:
@@ -227,7 +234,13 @@ def test_distributed_maskeddot(par1, par2):
227234
(par8), (par8b), (par9), (par9b)])
228235
def test_distributed_maskednorm(par):
229236
"""Test Distributed numpy.linalg.norm method with masked array"""
230-
nsub = 3 # number of subcommunicators
237+
# number of subcommunicators
238+
if MPI.COMM_WORLD.Get_size() % 2 == 0:
239+
nsub = 2
240+
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
241+
nsub = 3
242+
else:
243+
pass
231244
subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
232245
mask = np.repeat(np.arange(nsub), subsize)
233246
# Replicate x as required in masked arrays

0 commit comments

Comments
 (0)