@@ -218,7 +218,7 @@ def _create_local(self, local_call, distribution, dtype):
218218 ddpr = distribution .get_dim_data_per_rank ()
219219 ddpr_name , dtype_name = self ._key_and_push (ddpr , dtype )
220220 cmd = ('{da_key} = {local_call}(distarray.local.maps.Distribution('
221- '{ ddpr_name}[{comm_name}.Get_rank()], comm={comm_name} ), '
221+ 'comm={comm_name}, dim_data={ ddpr_name}[{comm_name}.Get_rank()]), '
222222 'dtype={dtype_name})' )
223223 self ._execute (cmd .format (** locals ()), targets = distribution .targets )
224224 return DistArray .from_localarrays (da_key , distribution = distribution ,
@@ -364,20 +364,18 @@ def load_dnpy(self, name):
364364 da_key = self ._generate_key ()
365365
366366 if isinstance (name , six .string_types ):
367- subs = (da_key ,) + self ._key_and_push (name ) + (self .comm ,
368- self .comm )
367+ subs = (da_key ,) + (self .comm ,) + self ._key_and_push (name ) + (self .comm ,)
369368 self ._execute (
370- '%s = distarray.local.load_dnpy(%s + "_" + str(%s.Get_rank()) + ".dnpy", %s )' % subs ,
369+ '%s = distarray.local.load_dnpy(%s, %s + "_" + str(%s.Get_rank()) + ".dnpy")' % subs ,
371370 targets = self .targets
372371 )
373372 elif isinstance (name , collections .Sequence ):
374373 if len (name ) != len (self .targets ):
375374 errmsg = "`name` must be the same length as `self.targets`."
376375 raise TypeError (errmsg )
377- subs = (da_key ,) + self ._key_and_push (name ) + (self .comm ,
378- self .comm )
376+ subs = (da_key ,) + (self .comm ,) + self ._key_and_push (name ) + (self .comm ,)
379377 self ._execute (
380- '%s = distarray.local.load_dnpy(%s[%s.Get_rank()], %s )' % subs ,
378+ '%s = distarray.local.load_dnpy(%s, %s [%s.Get_rank()])' % subs ,
381379 targets = self .targets
382380 )
383381 else :
@@ -438,16 +436,17 @@ def load_npy(self, filename, distribution):
438436 result : DistArray
439437 A DistArray encapsulating the file loaded.
440438 """
441- da_key = self ._generate_key ()
439+
440+ def _local_load_npy (filename , ddpr , comm ):
441+ from distarray .local import load_npy
442+ dim_data = ddpr [comm .Get_rank ()]
443+ return proxyize (load_npy (comm , filename , dim_data ))
444+
442445 ddpr = distribution .get_dim_data_per_rank ()
443- subs = ((da_key ,) + self ._key_and_push (filename , ddpr ) +
444- (distribution .comm ,) + (distribution .comm ,))
445446
446- self ._execute (
447- '%s = distarray.local.load_npy(%s, %s[%s.Get_rank()], %s)' % subs ,
448- targets = distribution .targets
449- )
450- return DistArray .from_localarrays (da_key , distribution = distribution )
447+ da_key = self .apply (_local_load_npy , (filename , ddpr , distribution .comm ),
448+ targets = distribution .targets )
449+ return DistArray .from_localarrays (da_key [0 ], distribution = distribution )
451450
452451 def load_hdf5 (self , filename , distribution , key = 'buffer' ):
453452 """
@@ -473,16 +472,17 @@ def load_hdf5(self, filename, distribution, key='buffer'):
473472 errmsg = "An MPI-enabled h5py must be available to use load_hdf5."
474473 raise ImportError (errmsg )
475474
476- da_key = self ._generate_key ()
475+ def _local_load_hdf5 (filename , ddpr , comm , key ):
476+ from distarray .local import load_hdf5
477+ dim_data = ddpr [comm .Get_rank ()]
478+ return proxyize (load_hdf5 (comm , filename , dim_data , key ))
479+
477480 ddpr = distribution .get_dim_data_per_rank ()
478- subs = ((da_key ,) + self ._key_and_push (filename , ddpr ) +
479- (distribution .comm ,) + self ._key_and_push (key ) + (distribution .comm ,))
480481
481- self ._execute (
482- '%s = distarray.local.load_hdf5(%s, %s[%s.Get_rank()], %s, %s)' % subs ,
483- targets = distribution .targets
484- )
485- return DistArray .from_localarrays (da_key , distribution = distribution )
482+ da_key = self .apply (_local_load_hdf5 , (filename , ddpr , distribution .comm , key ),
483+ targets = distribution .targets )
484+
485+ return DistArray .from_localarrays (da_key [0 ], distribution = distribution )
486486
487487 def fromndarray (self , arr , distribution = None ):
488488 """Create a DistArray from an ndarray.
@@ -530,7 +530,7 @@ def fromfunction(self, function, shape, **kwargs):
530530 comm_name = distribution .comm
531531 cmd = ('{da_name} = distarray.local.fromfunction({function_name}, '
532532 'distarray.local.maps.Distribution('
533- '{ ddpr_name}[{comm_name}.Get_rank()], comm={comm_name} ),'
533+ 'comm={comm_name}, dim_data={ ddpr_name}[{comm_name}.Get_rank()]),'
534534 '**{kwargs_name})' )
535535 self ._execute (cmd .format (** locals ()), targets = distribution .targets )
536536 return DistArray .from_localarrays (da_name , distribution = distribution )
0 commit comments