@@ -437,16 +437,18 @@ def compact_number(v):
437437def get_grid_dimensions (current_problem_size , params , grid_div , block_size_names ):
438438 """Compute grid dims based on problem sizes and listed grid divisors."""
439439
440- def get_dimension_divisor (divisor_list , default , params ):
441- if divisor_list is None :
442- if default in params :
443- divisor_list = [default ]
444- else :
445- return 1
446- if callable (divisor_list ):
447- return divisor_list (params )
440+ def get_dimension_divisor (divisor , default , params ):
441+ if divisor is None :
442+ divisor = params .get (default , 1 )
443+
444+ if isinstance (divisor , int ):
445+ return divisor
446+ elif callable (divisor ):
447+ return divisor (params )
448+ elif isinstance (divisor , str ):
449+ return int (eval (replace_param_occurrences (divisor , params )))
448450 else :
449- return np .prod ([int ( eval ( replace_param_occurrences ( s , params ))) for s in divisor_list ])
451+ return np .prod ([get_dimension_divisor ( s , 1 , params ) for s in divisor ])
450452
451453 divisors = [get_dimension_divisor (d , block_size_names [i ], params ) for i , d in enumerate (grid_div )]
452454 return tuple (int (np .ceil (float (current_problem_size [i ]) / float (d ))) for i , d in enumerate (divisors ))
0 commit comments