@@ -320,6 +320,12 @@ def local_chunks(self, domain, scales, rank=None, broadcast=False):
320320 local_chunks .append (np .arange (start , end ))
321321 return tuple (local_chunks )
322322
323+ def global_elements (self , domain , scales ):
324+ """Global element indices by axis."""
325+ global_shape = self .global_shape (domain , scales )
326+ indices = [np .arange (n ) for n in global_shape ]
327+ return tuple (indices )
328+
323329 def local_elements (self , domain , scales , rank = None , broadcast = False ):
324330 """Local element indices by axis."""
325331 chunk_shape = self .chunk_shape (domain )
@@ -345,12 +351,7 @@ def valid_elements(self, tensorsig, domain, scales, rank=None, broadcast=False):
345351 valid &= basis .valid_elements (tensorsig , grid_space [basis_axes ], elements [basis_axes ])
346352 return valid
347353
348- @CachedMethod
349- def local_group_arrays (self , domain , scales , rank = None , broadcast = False ):
350- """Dense array of local groups (first axis)."""
351- # Make dense array of local elements
352- elements = self .local_elements (domain , scales , rank = rank , broadcast = broadcast )
353- elements = np .array (np .meshgrid (* elements , indexing = 'ij' ))
354+ def _group_arrays (self , elements , domain ):
354355 # Convert to groups basis-by-basis
355356 grid_space = self .grid_space
356357 groups = np .zeros_like (elements )
@@ -360,6 +361,22 @@ def local_group_arrays(self, domain, scales, rank=None, broadcast=False):
360361 groups [basis_axes ] = basis .elements_to_groups (grid_space [basis_axes ], elements [basis_axes ])
361362 return groups
362363
364+ @CachedMethod
365+ def local_group_arrays (self , domain , scales , rank = None , broadcast = False ):
366+ """Dense array of local groups (first axis)."""
367+ # Make dense array of local elements
368+ elements = self .local_elements (domain , scales , rank = rank , broadcast = broadcast )
369+ elements = np .array (np .meshgrid (* elements , indexing = 'ij' ))
370+ return self ._group_arrays (elements , domain )
371+
372+ @CachedMethod
373+ def global_group_arrays (self , domain , scales ):
374+ """Dense array of local groups (first axis)."""
375+ # Make dense array of local elements
376+ elements = self .global_elements (domain , scales )
377+ elements = np .array (np .meshgrid (* elements , indexing = 'ij' ))
378+ return self ._group_arrays (elements , domain )
379+
363380 @CachedMethod
364381 def local_groupsets (self , group_coupling , domain , scales , rank = None , broadcast = False ):
365382 local_groupsets = self .local_group_arrays (domain , scales , rank = rank , broadcast = broadcast ).astype (object )
0 commit comments