3636 normalize_grid_shape ,
3737 make_grid_shape ,
3838 positivify ,
39- validate_grid_shape ,
4039 _start_stop_block ,
4140 normalize_dim_dict ,
4241 normalize_reduction_axes )
@@ -412,8 +411,8 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None):
412411 self .ndim = len (dd0 )
413412 self .dist = tuple (dd ['dist_type' ] for dd in dd0 )
414413 self .grid_shape = tuple (dd ['proc_grid_size' ] for dd in dd0 )
415-
416- validate_grid_shape ( self . grid_shape , self .dist , len (self .targets ))
414+ self . grid_shape = normalize_grid_shape ( self . grid_shape , self . ndim ,
415+ self .dist , len (self .targets ))
417416
418417 coords = [tuple (d ['proc_grid_rank' ] for d in dd ) for dd in
419418 dim_data_per_rank ]
@@ -439,7 +438,18 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None):
439438 return self
440439
441440 @classmethod
442- def from_shape (cls , context , shape , dist = None , grid_shape = None , targets = None ):
441+ def from_shape (cls , context , shape , dist = None , grid_shape = None ,
442+ targets = None ):
443+
444+ # special case when dist is all 'n's.
445+ if (dist is not None ) and all (d == 'n' for d in dist ):
446+ if (targets is not None ) and (len (targets ) != 1 ):
447+ raise ValueError ('target dist conflict' )
448+ elif targets is None :
449+ targets = [context .targets [0 ]]
450+ else :
451+ # then targets is set correctly
452+ pass
443453
444454 self = cls .__new__ (cls )
445455 self .context = context
@@ -448,17 +458,18 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None, targets=None):
448458 self .shape = shape
449459 self .ndim = len (shape )
450460
461+ # dist
451462 if dist is None :
452463 dist = {0 : 'b' }
453464 self .dist = normalize_dist (dist , self .ndim )
454465
455- if grid_shape is None : # Make a new grid_shape if not provided.
456- self . grid_shape = make_grid_shape ( self . shape , self . dist ,
457- len (self .targets ))
458- else : # Otherwise normalize the one passed in.
459- self . grid_shape = normalize_grid_shape ( grid_shape , self . ndim )
460- # In either case, validate.
461- validate_grid_shape ( self . grid_shape , self .dist , len (self .targets ))
466+ # grid_shape
467+ if grid_shape is None :
468+ grid_shape = make_grid_shape (self .shape , self . dist ,
469+ len ( self . targets ))
470+
471+ self . grid_shape = normalize_grid_shape ( grid_shape , self . ndim ,
472+ self .dist , len (self .targets ))
462473
463474 # TODO: FIXME: assert that self.rank_from_coords is valid and conforms
464475 # to how MPI does it.
@@ -568,7 +579,8 @@ def __init__(self, context, global_dim_data, targets=None):
568579 self .dist = tuple (m .dist for m in self .maps )
569580 self .grid_shape = tuple (m .grid_size for m in self .maps )
570581
571- validate_grid_shape (self .grid_shape , self .dist , len (context .targets ))
582+ self .grid_shape = normalize_grid_shape (self .grid_shape , self .ndim ,
583+ self .dist , len (context .targets ))
572584
573585 nelts = reduce (operator .mul , self .grid_shape , 1 )
574586 self .rank_from_coords = np .arange (nelts ).reshape (self .grid_shape )
0 commit comments