Skip to content

Commit f0fc6ba

Browse files
author
Kurt Smith
committed
Make comm first argument in localarray load_* functions.
1 parent d66c262 commit f0fc6ba

File tree

5 files changed

+42
-41
lines changed

5 files changed

+42
-41
lines changed

distarray/dist/context.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -364,20 +364,18 @@ def load_dnpy(self, name):
364364
da_key = self._generate_key()
365365

366366
if isinstance(name, six.string_types):
367-
subs = (da_key,) + self._key_and_push(name) + (self.comm,
368-
self.comm)
367+
subs = (da_key,) + (self.comm,) + self._key_and_push(name) + (self.comm,)
369368
self._execute(
370-
'%s = distarray.local.load_dnpy(%s + "_" + str(%s.Get_rank()) + ".dnpy", %s)' % subs,
369+
'%s = distarray.local.load_dnpy(%s, %s + "_" + str(%s.Get_rank()) + ".dnpy")' % subs,
371370
targets=self.targets
372371
)
373372
elif isinstance(name, collections.Sequence):
374373
if len(name) != len(self.targets):
375374
errmsg = "`name` must be the same length as `self.targets`."
376375
raise TypeError(errmsg)
377-
subs = (da_key,) + self._key_and_push(name) + (self.comm,
378-
self.comm)
376+
subs = (da_key,) + (self.comm,) + self._key_and_push(name) + (self.comm,)
379377
self._execute(
380-
'%s = distarray.local.load_dnpy(%s[%s.Get_rank()], %s)' % subs,
378+
'%s = distarray.local.load_dnpy(%s, %s[%s.Get_rank()])' % subs,
381379
targets=self.targets
382380
)
383381
else:
@@ -438,16 +436,17 @@ def load_npy(self, filename, distribution):
438436
result : DistArray
439437
A DistArray encapsulating the file loaded.
440438
"""
441-
da_key = self._generate_key()
439+
440+
def _local_load_npy(filename, ddpr, comm):
441+
from distarray.local import load_npy
442+
dim_data = ddpr[comm.Get_rank()]
443+
return proxyize(load_npy(comm, filename, dim_data))
444+
442445
ddpr = distribution.get_dim_data_per_rank()
443-
subs = ((da_key,) + self._key_and_push(filename, ddpr) +
444-
(distribution.comm,) + (distribution.comm,))
445446

446-
self._execute(
447-
'%s = distarray.local.load_npy(%s, %s[%s.Get_rank()], %s)' % subs,
448-
targets=distribution.targets
449-
)
450-
return DistArray.from_localarrays(da_key, distribution=distribution)
447+
da_key = self.apply(_local_load_npy, (filename, ddpr, distribution.comm),
448+
targets=distribution.targets)
449+
return DistArray.from_localarrays(da_key[0], distribution=distribution)
451450

452451
def load_hdf5(self, filename, distribution, key='buffer'):
453452
"""
@@ -476,7 +475,7 @@ def load_hdf5(self, filename, distribution, key='buffer'):
476475
def _local_load_hdf5(filename, ddpr, comm, key):
477476
from distarray.local import load_hdf5
478477
dim_data = ddpr[comm.Get_rank()]
479-
return proxyize(load_hdf5(filename, dim_data, comm, key))
478+
return proxyize(load_hdf5(comm, filename, dim_data, key))
480479

481480
ddpr = distribution.get_dim_data_per_rank()
482481

distarray/dist/distarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def _reduce(self, local_reduce_name, axes=None, dtype=None, out=None):
286286
def _local_reduce(local_name, larr, out_comm, ddpr, dtype, axes):
287287
import distarray.local.localarray as la
288288
local_reducer = getattr(la, local_name)
289-
res = proxyize(la.local_reduction(local_reducer, out_comm, larr, # noqa
289+
res = proxyize(la.local_reduction(out_comm, local_reducer, larr, # noqa
290290
ddpr, dtype, axes))
291291
return res
292292

distarray/local/localarray.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def compatibility_hash(self):
198198
#-------------------------------------------------------------------------
199199

200200
@classmethod
201-
def from_distarray(cls, obj, comm):
201+
def from_distarray(cls, comm, obj):
202202
"""Make a LocalArray from Distributed Array Protocol data structure.
203203
204204
An object that supports the Distributed Array Protocol will have
@@ -625,7 +625,7 @@ def save_dnpy(file, arr):
625625
fid.close()
626626

627627

628-
def load_dnpy(file, comm):
628+
def load_dnpy(comm, file):
629629
"""
630630
Load a LocalArray from a ``.dnpy`` file.
631631
@@ -649,7 +649,7 @@ def load_dnpy(file, comm):
649649

650650
try:
651651
distbuffer = format.read_localarray(fid)
652-
return LocalArray.from_distarray(distbuffer, comm=comm)
652+
return LocalArray.from_distarray(comm=comm, obj=distbuffer)
653653

654654
finally:
655655
if own_fid:
@@ -741,7 +741,7 @@ def unstructured_index(dd):
741741
return tuple(index)
742742

743743

744-
def load_hdf5(filename, dim_data, comm, key='buffer'):
744+
def load_hdf5(comm, filename, dim_data, key='buffer'):
745745
"""
746746
Load a LocalArray from an ``.hdf5`` file.
747747
@@ -789,7 +789,7 @@ def load_hdf5(filename, dim_data, comm, key='buffer'):
789789
return LocalArray(distribution=distribution, dtype=dtype, buf=buf)
790790

791791

792-
def load_npy(filename, dim_data, comm):
792+
def load_npy(comm, filename, dim_data):
793793
"""
794794
Load a LocalArray from a ``.npy`` file.
795795
@@ -891,7 +891,7 @@ def get_printoptions():
891891
# Reduction functions
892892
# ---------------------------------------------------------------------------
893893

894-
def local_reduction(reducer, out_comm, larr, ddpr, dtype, axes):
894+
def local_reduction(out_comm, reducer, larr, ddpr, dtype, axes):
895895
""" Entry point for reductions on local arrays.
896896
897897
Parameters

distarray/local/tests/paralleltest_distributed_array_protocol.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ def assert_round_trip_equality(self, larr):
4343
assert_array_equal(larr.ndarray, self.larr.ndarray)
4444

4545
def test_round_trip_equality_from_object(self):
46-
larr = LocalArray.from_distarray(self.larr, comm=self.comm)
46+
larr = LocalArray.from_distarray(comm=self.comm, obj=self.larr)
4747
self.assert_round_trip_equality(larr)
4848

4949
def test_round_trip_equality_from_dict(self):
50-
larr = LocalArray.from_distarray(self.larr.__distarray__(),
51-
comm=self.comm)
50+
larr = LocalArray.from_distarray(comm=self.comm, obj=self.larr.__distarray__())
5251
self.assert_round_trip_equality(larr)
5352

5453

@@ -65,7 +64,7 @@ def test_with_validator(self):
6564
validate_distbuffer(self.larr.__distarray__())
6665

6766
def test_round_trip_elements(self):
68-
larr = LocalArray.from_distarray(self.larr, comm=self.comm)
67+
larr = LocalArray.from_distarray(comm=self.comm, obj=self.larr)
6968
if self.comm.Get_rank() == 0:
7069
idx = (0,) * larr.ndarray.ndim
7170
larr.ndarray[idx] = 99
@@ -183,7 +182,7 @@ def test_values(self):
183182
elif self.comm.Get_rank() == 1:
184183
assert_array_equal(np.arange(30), self.larr.ndarray)
185184

186-
larr = LocalArray.from_distarray(self.larr, comm=self.comm)
185+
larr = LocalArray.from_distarray(comm=self.comm, obj=self.larr)
187186
if self.comm.Get_rank() == 0:
188187
assert_array_equal(np.arange(20), larr.ndarray)
189188
elif self.comm.Get_rank() == 1:

distarray/local/tests/paralleltest_io.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ def test_flat_file_save_with_file_object(self):
4747

4848
def test_flat_file_save_load_with_filename(self):
4949
save_dnpy(self.output_path, self.larr0)
50-
larr1 = load_dnpy(self.output_path, comm=self.comm)
50+
larr1 = load_dnpy(comm=self.comm, file=self.output_path)
5151
self.assertTrue(isinstance(larr1, LocalArray))
5252
assert_allclose(self.larr0, larr1)
5353

5454
def test_flat_file_save_load_with_file_object(self):
5555
save_dnpy(self.output_path, self.larr0)
5656
with open(self.output_path, 'rb') as fp:
57-
larr1 = load_dnpy(fp, comm=self.comm)
57+
larr1 = load_dnpy(comm=self.comm, file=fp)
5858
self.assertTrue(isinstance(larr1, LocalArray))
5959
assert_allclose(self.larr0, larr1)
6060

@@ -162,25 +162,25 @@ def tearDown(self):
162162

163163
def test_load_bn(self):
164164
dim_data_per_rank = bn_test_data
165-
la = load_npy(self.output_path, dim_data_per_rank[self.rank],
166-
comm=self.comm)
165+
la = load_npy(comm=self.comm, filename=self.output_path,
166+
dim_data=dim_data_per_rank[self.rank])
167167
assert_equal(la, self.expected[numpy.newaxis, self.rank])
168168

169169
def test_load_nc(self):
170170
dim_data_per_rank = nc_test_data
171171
expected_slices = [(slice(None), slice(0, None, 2)),
172172
(slice(None), slice(1, None, 2))]
173173

174-
la = load_npy(self.output_path, dim_data_per_rank[self.rank],
175-
comm=self.comm)
174+
la = load_npy(comm=self.comm, filename=self.output_path,
175+
dim_data=dim_data_per_rank[self.rank])
176176
assert_equal(la, self.expected[expected_slices[self.rank]])
177177

178178
def test_load_nu(self):
179179
dim_data_per_rank = nu_test_data
180180
expected_indices = [dd[1]['indices'] for dd in dim_data_per_rank]
181181

182-
la = load_npy(self.output_path, dim_data_per_rank[self.rank],
183-
comm=self.comm)
182+
la = load_npy(comm=self.comm, filename=self.output_path,
183+
dim_data=dim_data_per_rank[self.rank])
184184
assert_equal(la, self.expected[:, expected_indices[self.rank]])
185185

186186

@@ -257,8 +257,9 @@ def tearDown(self):
257257

258258
def test_load_bn(self):
259259
dim_data_per_rank = bn_test_data
260-
la = load_hdf5(self.output_path, dim_data_per_rank[self.rank],
261-
key=self.key, comm=self.comm)
260+
la = load_hdf5(comm=self.comm, filename=self.output_path,
261+
dim_data=dim_data_per_rank[self.rank],
262+
key=self.key)
262263
with self.h5py.File(self.output_path, 'r', driver='mpio',
263264
comm=self.comm) as fp:
264265
assert_equal(la, self.expected[numpy.newaxis, self.rank])
@@ -267,8 +268,9 @@ def test_load_nc(self):
267268
dim_data_per_rank = nc_test_data
268269
expected_slices = [(slice(None), slice(0, None, 2)),
269270
(slice(None), slice(1, None, 2))]
270-
la = load_hdf5(self.output_path, dim_data_per_rank[self.rank],
271-
key=self.key, comm=self.comm)
271+
la = load_hdf5(comm=self.comm, filename=self.output_path,
272+
dim_data=dim_data_per_rank[self.rank],
273+
key=self.key)
272274
with self.h5py.File(self.output_path, 'r', driver='mpio',
273275
comm=self.comm) as fp:
274276
expected_slice = expected_slices[self.rank]
@@ -277,8 +279,9 @@ def test_load_nc(self):
277279
def test_load_nu(self):
278280
dim_data_per_rank = nu_test_data
279281
expected_indices = [dd[1]['indices'] for dd in dim_data_per_rank]
280-
la = load_hdf5(self.output_path, dim_data_per_rank[self.rank],
281-
key=self.key, comm=self.comm)
282+
la = load_hdf5(comm=self.comm, filename=self.output_path,
283+
dim_data=dim_data_per_rank[self.rank],
284+
key=self.key)
282285
with self.h5py.File(self.output_path, 'r', driver='mpio',
283286
comm=self.comm) as fp:
284287
assert_equal(la, self.expected[:, expected_indices[self.rank]])

0 commit comments

Comments
 (0)