Skip to content

Commit 118dd6a

Browse files
committed
Merge pull request #412 from enthought/bugfix/144
Fixes #144.
2 parents 48bd3d7 + f0fc6ba commit 118dd6a

14 files changed

+214
-196
lines changed

distarray/dist/context.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def _create_local(self, local_call, distribution, dtype):
218218
ddpr = distribution.get_dim_data_per_rank()
219219
ddpr_name, dtype_name = self._key_and_push(ddpr, dtype)
220220
cmd = ('{da_key} = {local_call}(distarray.local.maps.Distribution('
221-
'{ddpr_name}[{comm_name}.Get_rank()], comm={comm_name}), '
221+
'comm={comm_name}, dim_data={ddpr_name}[{comm_name}.Get_rank()]), '
222222
'dtype={dtype_name})')
223223
self._execute(cmd.format(**locals()), targets=distribution.targets)
224224
return DistArray.from_localarrays(da_key, distribution=distribution,
@@ -364,20 +364,18 @@ def load_dnpy(self, name):
364364
da_key = self._generate_key()
365365

366366
if isinstance(name, six.string_types):
367-
subs = (da_key,) + self._key_and_push(name) + (self.comm,
368-
self.comm)
367+
subs = (da_key,) + (self.comm,) + self._key_and_push(name) + (self.comm,)
369368
self._execute(
370-
'%s = distarray.local.load_dnpy(%s + "_" + str(%s.Get_rank()) + ".dnpy", %s)' % subs,
369+
'%s = distarray.local.load_dnpy(%s, %s + "_" + str(%s.Get_rank()) + ".dnpy")' % subs,
371370
targets=self.targets
372371
)
373372
elif isinstance(name, collections.Sequence):
374373
if len(name) != len(self.targets):
375374
errmsg = "`name` must be the same length as `self.targets`."
376375
raise TypeError(errmsg)
377-
subs = (da_key,) + self._key_and_push(name) + (self.comm,
378-
self.comm)
376+
subs = (da_key,) + (self.comm,) + self._key_and_push(name) + (self.comm,)
379377
self._execute(
380-
'%s = distarray.local.load_dnpy(%s[%s.Get_rank()], %s)' % subs,
378+
'%s = distarray.local.load_dnpy(%s, %s[%s.Get_rank()])' % subs,
381379
targets=self.targets
382380
)
383381
else:
@@ -438,16 +436,17 @@ def load_npy(self, filename, distribution):
438436
result : DistArray
439437
A DistArray encapsulating the file loaded.
440438
"""
441-
da_key = self._generate_key()
439+
440+
def _local_load_npy(filename, ddpr, comm):
441+
from distarray.local import load_npy
442+
dim_data = ddpr[comm.Get_rank()]
443+
return proxyize(load_npy(comm, filename, dim_data))
444+
442445
ddpr = distribution.get_dim_data_per_rank()
443-
subs = ((da_key,) + self._key_and_push(filename, ddpr) +
444-
(distribution.comm,) + (distribution.comm,))
445446

446-
self._execute(
447-
'%s = distarray.local.load_npy(%s, %s[%s.Get_rank()], %s)' % subs,
448-
targets=distribution.targets
449-
)
450-
return DistArray.from_localarrays(da_key, distribution=distribution)
447+
da_key = self.apply(_local_load_npy, (filename, ddpr, distribution.comm),
448+
targets=distribution.targets)
449+
return DistArray.from_localarrays(da_key[0], distribution=distribution)
451450

452451
def load_hdf5(self, filename, distribution, key='buffer'):
453452
"""
@@ -473,16 +472,17 @@ def load_hdf5(self, filename, distribution, key='buffer'):
473472
errmsg = "An MPI-enabled h5py must be available to use load_hdf5."
474473
raise ImportError(errmsg)
475474

476-
da_key = self._generate_key()
475+
def _local_load_hdf5(filename, ddpr, comm, key):
476+
from distarray.local import load_hdf5
477+
dim_data = ddpr[comm.Get_rank()]
478+
return proxyize(load_hdf5(comm, filename, dim_data, key))
479+
477480
ddpr = distribution.get_dim_data_per_rank()
478-
subs = ((da_key,) + self._key_and_push(filename, ddpr) +
479-
(distribution.comm,) + self._key_and_push(key) + (distribution.comm,))
480481

481-
self._execute(
482-
'%s = distarray.local.load_hdf5(%s, %s[%s.Get_rank()], %s, %s)' % subs,
483-
targets=distribution.targets
484-
)
485-
return DistArray.from_localarrays(da_key, distribution=distribution)
482+
da_key = self.apply(_local_load_hdf5, (filename, ddpr, distribution.comm, key),
483+
targets=distribution.targets)
484+
485+
return DistArray.from_localarrays(da_key[0], distribution=distribution)
486486

487487
def fromndarray(self, arr, distribution=None):
488488
"""Create a DistArray from an ndarray.
@@ -530,7 +530,7 @@ def fromfunction(self, function, shape, **kwargs):
530530
comm_name = distribution.comm
531531
cmd = ('{da_name} = distarray.local.fromfunction({function_name}, '
532532
'distarray.local.maps.Distribution('
533-
'{ddpr_name}[{comm_name}.Get_rank()], comm={comm_name}),'
533+
'comm={comm_name}, dim_data={ddpr_name}[{comm_name}.Get_rank()]),'
534534
'**{kwargs_name})')
535535
self._execute(cmd.format(**locals()), targets=distribution.targets)
536536
return DistArray.from_localarrays(da_name, distribution=distribution)

distarray/dist/distarray.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def __init__(self, distribution, dtype=float):
4949
ddpr_name, dtype_name = ctx._key_and_push(ddpr, dtype)
5050
cmd = ('{da_key} = distarray.local.empty('
5151
'distarray.local.maps.Distribution('
52-
'{ddpr_name}[{comm_name}.Get_rank()], '
53-
'{comm_name}), {dtype_name})')
52+
'comm={comm_name}, dim_data={ddpr_name}[{comm_name}.Get_rank()]), '
53+
'{dtype_name})')
5454
ctx._execute(cmd.format(**locals()), targets=distribution.targets)
5555
self.distribution = distribution
5656
self.key = da_key
@@ -286,7 +286,7 @@ def _reduce(self, local_reduce_name, axes=None, dtype=None, out=None):
286286
def _local_reduce(local_name, larr, out_comm, ddpr, dtype, axes):
287287
import distarray.local.localarray as la
288288
local_reducer = getattr(la, local_name)
289-
res = proxyize(la.local_reduction(local_reducer, out_comm, larr, # noqa
289+
res = proxyize(la.local_reduction(out_comm, local_reducer, larr, # noqa
290290
ddpr, dtype, axes))
291291
return res
292292

distarray/local/construct.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from distarray.local.mpiutils import MPI
1010
from distarray.local.error import NullCommError, InvalidBaseCommError
11-
from distarray.local import mpiutils
1211

1312

1413
# ---------------------------------------------------------------------------
@@ -27,8 +26,6 @@ def init_base_comm(comm):
2726
"""Sanitize an MPI.comm instance or create one."""
2827
if comm == MPI.COMM_NULL:
2928
raise NullCommError("Cannot create a LocalArray with COMM_NULL")
30-
elif comm is None:
31-
return mpiutils.COMM_PRIVATE
3229
elif isinstance(comm, MPI.Comm):
3330
return comm
3431
else:

distarray/local/localarray.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
# ---------------------------------------------------------------------------
1111
# Imports
1212
# ---------------------------------------------------------------------------
13-
import math
1413
from collections import Mapping
1514
from numbers import Integral
1615

@@ -19,7 +18,6 @@
1918
from distarray.externals import six
2019
from distarray.externals.six.moves import zip
2120

22-
import distarray.local
2321
from distarray.local.mpiutils import MPI
2422
from distarray.utils import _raise_nie
2523
from distarray.local import format, maps
@@ -200,7 +198,7 @@ def compatibility_hash(self):
200198
#-------------------------------------------------------------------------
201199

202200
@classmethod
203-
def from_distarray(cls, obj, comm=None):
201+
def from_distarray(cls, comm, obj):
204202
"""Make a LocalArray from Distributed Array Protocol data structure.
205203
206204
An object that supports the Distributed Array Protocol will have
@@ -228,7 +226,7 @@ def from_distarray(cls, obj, comm=None):
228226
buf = np.asarray(distbuffer['buffer'])
229227
dim_data = distbuffer['dim_data']
230228

231-
distribution = maps.Distribution(dim_data=dim_data, comm=comm)
229+
distribution = maps.Distribution(comm=comm, dim_data=dim_data)
232230
return cls(distribution=distribution, buf=buf)
233231

234232
def __distarray__(self):
@@ -627,7 +625,7 @@ def save_dnpy(file, arr):
627625
fid.close()
628626

629627

630-
def load_dnpy(file, comm=None):
628+
def load_dnpy(comm, file):
631629
"""
632630
Load a LocalArray from a ``.dnpy`` file.
633631
@@ -651,7 +649,7 @@ def load_dnpy(file, comm=None):
651649

652650
try:
653651
distbuffer = format.read_localarray(fid)
654-
return LocalArray.from_distarray(distbuffer, comm=comm)
652+
return LocalArray.from_distarray(comm=comm, obj=distbuffer)
655653

656654
finally:
657655
if own_fid:
@@ -743,7 +741,7 @@ def unstructured_index(dd):
743741
return tuple(index)
744742

745743

746-
def load_hdf5(filename, dim_data, key='buffer', comm=None):
744+
def load_hdf5(comm, filename, dim_data, key='buffer'):
747745
"""
748746
Load a LocalArray from an ``.hdf5`` file.
749747
@@ -756,10 +754,10 @@ def load_hdf5(filename, dim_data, key='buffer', comm=None):
756754
https://github.com/enthought/distributed-array-protocol, describing
757755
which portions of the HDF5 file to load into this LocalArray, and with
758756
what metadata.
757+
comm : MPI comm object
759758
key : str, optional
760759
The identifier for the group to load the LocalArray from (the default
761760
is 'buffer').
762-
comm : MPI comm object, optional
763761
764762
Returns
765763
-------
@@ -787,11 +785,11 @@ def load_hdf5(filename, dim_data, key='buffer', comm=None):
787785
buf = dset[index]
788786
dtype = dset.dtype
789787

790-
distribution = maps.Distribution(dim_data=dim_data, comm=comm)
788+
distribution = maps.Distribution(comm=comm, dim_data=dim_data)
791789
return LocalArray(distribution=distribution, dtype=dtype, buf=buf)
792790

793791

794-
def load_npy(filename, dim_data, comm=None):
792+
def load_npy(comm, filename, dim_data):
795793
"""
796794
Load a LocalArray from a ``.npy`` file.
797795
@@ -804,7 +802,7 @@ def load_npy(filename, dim_data, comm=None):
804802
https://github.com/enthought/distributed-array-protocol, describing
805803
which portions of the HDF5 file to load into this LocalArray, and with
806804
what metadata.
807-
comm : MPI comm object, optional
805+
comm : MPI comm object
808806
809807
Returns
810808
-------
@@ -823,7 +821,7 @@ def load_npy(filename, dim_data, comm=None):
823821
# http://stackoverflow.com/questions/6397495/unmap-of-numpy-memmap
824822

825823
#data._mmap.close()
826-
distribution = maps.Distribution(dim_data=dim_data, comm=comm)
824+
distribution = maps.Distribution(comm=comm, dim_data=dim_data)
827825
return LocalArray(distribution=distribution, dtype=data.dtype, buf=buf)
828826

829827

@@ -893,7 +891,7 @@ def get_printoptions():
893891
# Reduction functions
894892
# ---------------------------------------------------------------------------
895893

896-
def local_reduction(reducer, out_comm, larr, ddpr, dtype, axes):
894+
def local_reduction(out_comm, reducer, larr, ddpr, dtype, axes):
897895
""" Entry point for reductions on local arrays.
898896
899897
Parameters
@@ -923,8 +921,8 @@ def local_reduction(reducer, out_comm, larr, ddpr, dtype, axes):
923921
out = None
924922
else:
925923
dim_data = ddpr[out_comm.Get_rank()] if ddpr else ()
926-
dist = distarray.local.maps.Distribution(dim_data, out_comm)
927-
out = distarray.local.empty(dist, dtype)
924+
dist = maps.Distribution(comm=out_comm, dim_data=dim_data)
925+
out = empty(dist, dtype)
928926

929927
remaining_dims = [False] * larr.ndim
930928
for axis in axes:

distarray/local/maps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ class Distribution(object):
3939
Manages one or more one-dimensional map classes.
4040
"""
4141

42-
def __init__(self, dim_data, comm=None):
42+
def __init__(self, comm, dim_data):
4343
"""Create a Distribution from a `dim_data` structure."""
4444
self._maps = tuple(map_from_dim_dict(dim_dict) for dim_dict in dim_data)
4545
self.base_comm = construct.init_base_comm(comm)
4646
self.comm = construct.init_comm(self.base_comm, self.grid_shape)
4747

4848
@classmethod
49-
def from_shape(cls, shape, dist=None, grid_shape=None, comm=None):
49+
def from_shape(cls, comm, shape, dist=None, grid_shape=None):
5050
"""Create a Distribution from a `shape` and optional arguments."""
5151
dist = {0: 'b'} if dist is None else dist
5252
ndim = len(shape)
@@ -72,7 +72,7 @@ def from_shape(cls, shape, dist=None, grid_shape=None, comm=None):
7272
distribute_indices(dim_dict)
7373
dim_data.append(dim_dict)
7474

75-
return cls(dim_data, comm=base_comm)
75+
return cls(comm=base_comm, dim_data=dim_data)
7676

7777
def __getitem__(self, idx):
7878
return self._maps[idx]

0 commit comments

Comments
 (0)