@@ -442,16 +442,24 @@ def compact_number(v):
442
442
def get_grid_dimensions (current_problem_size , params , grid_div , block_size_names ):
443
443
"""Compute grid dims based on problem sizes and listed grid divisors."""
444
444
445
- def get_dimension_divisor (divisor_list , default , params ):
446
- if divisor_list is None :
447
- if default in params :
448
- divisor_list = [default ]
449
- else :
450
- return 1
451
- if callable (divisor_list ):
452
- return divisor_list (params )
445
+ def get_dimension_divisor (divisor , default , params ):
446
+ divisor_num = 1
447
+
448
+ if divisor is None :
449
+ divisor_num = params .get (default , 1 )
450
+ elif isinstance (divisor , int ):
451
+ divisor_num = divisor
452
+ elif callable (divisor ):
453
+ divisor_num = divisor (params )
454
+ elif isinstance (divisor , str ):
455
+ divisor_num = int (eval (replace_param_occurrences (divisor , params )))
456
+ elif np .iterable (divisor ):
457
+ for div in divisor :
458
+ divisor_num *= get_dimension_divisor (div , 1 , params )
453
459
else :
454
- return np .prod ([int (eval (replace_param_occurrences (s , params ))) for s in divisor_list ])
460
+ raise ValueError ("Error: unrecognized type in grid divisor list, should be any of int, str, callable, or iterable" )
461
+
462
+ return divisor_num
455
463
456
464
divisors = [get_dimension_divisor (d , block_size_names [i ], params ) for i , d in enumerate (grid_div )]
457
465
return tuple (int (np .ceil (float (current_problem_size [i ]) / float (d ))) for i , d in enumerate (divisors ))
0 commit comments