@@ -98,7 +98,8 @@ def __shape_setup__(cls, **kwargs):
9898 shape = kwargs .get ('shape' , kwargs .get ('shape_global' ))
9999 dimensions = kwargs .get ('dimensions' )
100100 npoint = kwargs .get ('npoint' , kwargs .get ('npoint_global' ))
101- glb_npoint = SparseDistributor .decompose (npoint , grid .distributor )
101+ distributor = kwargs .get ('distributor' , SparseDistributor )
102+ glb_npoint = distributor .decompose (npoint , grid .distributor )
102103 # Plain SparseFunction construction with npoint.
103104 if shape is None :
104105 loc_shape = (glb_npoint [grid .distributor .myrank ],)
@@ -184,7 +185,6 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
184185
185186 # Given an array or nothing, create dimension and SubFunction
186187 if key is not None :
187- dimensions = (self ._sparse_dim , Dimension (name = 'd' ))
188188 if key .ndim > 2 :
189189 dimensions = (self ._sparse_dim , Dimension (name = 'd' ),
190190 * mkdims ("i" , n = key .ndim - 2 ))
@@ -211,14 +211,21 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
211211 else :
212212 dtype = dtype or self .dtype
213213
214+ # Wether to initialize the subfunction with the provided data
215+ # Useful when rebuilding with a placeholder array only used to
216+ # infer shape and dtype and set the actual data later
217+ if kwargs .get ('init_subfunc' , True ):
218+ init = {'initializer' : key }
219+ else :
220+ init = {}
221+
214222 # Complex coordinates are not valid, so fall back to corresponding
215223 # real floating point type if dtype is complex.
216224 dtype = dtype (0 ).real .__class__
217-
218225 sf = SparseSubFunction (
219226 name = name , dtype = dtype , dimensions = dimensions ,
220- shape = shape , space_order = 0 , initializer = key , alias = self .alias ,
221- distributor = self ._distributor , parent = self
227+ shape = shape , space_order = 0 , alias = self .alias ,
228+ distributor = self ._distributor , parent = self , ** init
222229 )
223230
224231 if self .npoint == 0 :
@@ -230,6 +237,10 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
230237
231238 return sf
232239
240+ @property
241+ def is_local (self ):
242+ return self ._distributor ._is_local
243+
233244 @property
234245 def sparse_position (self ):
235246 return self ._sparse_position
@@ -534,7 +545,7 @@ def _dist_data_scatter(self, data=None):
534545 data = data if data is not None else self .data ._local
535546
536547 # If not using MPI, don't waste time
537- if self ._distributor .nprocs == 1 :
548+ if self ._distributor .nprocs == 1 or self . is_local :
538549 return data
539550
540551 # Compute dist map only once
@@ -556,8 +567,13 @@ def _dist_data_scatter(self, data=None):
556567
557568 def _dist_subfunc_scatter (self , subfunc ):
558569 # If not using MPI, don't waste time
559- if self ._distributor .nprocs == 1 :
560- return {subfunc : subfunc .data }
570+ if self ._distributor .nprocs == 1 or self .is_local :
571+ if self .is_local and self .dist_origin [subfunc ] is not None :
572+ shift = np .array (self .dist_origin [subfunc ], dtype = subfunc .dtype )
573+ subfuncd = subfunc .data ._local - shift
574+ else :
575+ subfuncd = subfunc .data
576+ return {subfunc : subfuncd }
561577
562578 # Compute dist map only once
563579 dmap = self ._dist_datamap
@@ -581,7 +597,7 @@ def _dist_subfunc_scatter(self, subfunc):
581597
582598 def _dist_data_gather (self , data ):
583599 # If not using MPI, don't waste time
584- if self ._distributor .nprocs == 1 :
600+ if self ._distributor .nprocs == 1 or self . is_local :
585601 return
586602
587603 # Compute dist map only once
@@ -612,7 +628,7 @@ def _dist_subfunc_gather(self, sfuncd, subfunc):
612628 except AttributeError :
613629 pass
614630 # If not using MPI, don't waste time
615- if self ._distributor .nprocs == 1 :
631+ if self ._distributor .nprocs == 1 or self . is_local :
616632 return
617633
618634 # Compute dist map only once
0 commit comments