@@ -437,16 +437,18 @@ def compact_number(v):
437
437
def get_grid_dimensions (current_problem_size , params , grid_div , block_size_names ):
438
438
"""Compute grid dims based on problem sizes and listed grid divisors."""
439
439
440
- def get_dimension_divisor (divisor_list , default , params ):
441
- if divisor_list is None :
442
- if default in params :
443
- divisor_list = [default ]
444
- else :
445
- return 1
446
- if callable (divisor_list ):
447
- return divisor_list (params )
440
+ def get_dimension_divisor (divisor , default , params ):
441
+ if divisor is None :
442
+ divisor = params .get (default , 1 )
443
+
444
+ if isinstance (divisor , int ):
445
+ return divisor
446
+ elif callable (divisor ):
447
+ return divisor (params )
448
+ elif isinstance (divisor , str ):
449
+ return int (eval (replace_param_occurrences (divisor , params )))
448
450
else :
449
- return np .prod ([int ( eval ( replace_param_occurrences ( s , params ))) for s in divisor_list ])
451
+ return np .prod ([get_dimension_divisor ( s , 1 , params ) for s in divisor ])
450
452
451
453
divisors = [get_dimension_divisor (d , block_size_names [i ], params ) for i , d in enumerate (grid_div )]
452
454
return tuple (int (np .ceil (float (current_problem_size [i ]) / float (d ))) for i , d in enumerate (divisors ))
0 commit comments