@@ -350,6 +350,22 @@ def check_padding(padding):
350350 return padding
351351
352352
353+ # simple heuristic to determine suitable shape from min and step
354+ def _determine_shape (min_shape , step , axes ):
355+ is3d = "z" in axes
356+ min_len = 64 if is3d else 256
357+ shape = []
358+ for ax , min_ax , step_ax in zip (axes , min_shape , step ):
359+ if ax in "zyx" and step_ax > 0 :
360+ len_ax = min_ax
361+ while len_ax < min_len :
362+ len_ax += step_ax
363+ shape .append (len_ax )
364+ else :
365+ shape .append (min_ax )
366+ return shape
367+
368+
353369def parse_tiling (tiling , model ):
354370 if tiling is None : # no tiling
355371 return tiling
@@ -361,9 +377,15 @@ def parse_tiling(tiling, model):
361377
362378 input_spec = model .inputs [0 ]
363379 output_spec = model .outputs [0 ]
380+ axes = input_spec .axes
364381
365382 def check_tiling (tiling ):
366383 assert "halo" in tiling and "tile" in tiling
384+ spatial_axes = [ax for ax in axes if ax in "xyz" ]
385+ halo = tiling ["halo" ]
386+ tile = tiling ["tile" ]
387+ assert all (halo .get (ax , 0 ) > 0 for ax in spatial_axes )
388+ assert all (tile .get (ax , 0 ) > 0 for ax in spatial_axes )
367389
368390 if isinstance (tiling , dict ):
369391 check_tiling (tiling )
@@ -374,20 +396,21 @@ def check_tiling(tiling):
374396 # output space and then request the corresponding input tiles
375397 # so we would need to apply the output scale and offset to the
376398 # input shape to compute the tile size and halo here
377- axes = input_spec .axes
378399 shape = input_spec .shape
379400 if not isinstance (shape , list ):
380- # NOTE this might result in very small tiles.
381- # it would be good to have some heuristic to determine a suitable tilesize
382- # from shape.min and shape.step
383- shape = shape . min
401+ shape = _determine_shape ( shape . min , shape . step , axes )
402+ assert isinstance ( shape , list )
403+ assert len ( shape ) == len ( axes )
404+
384405 halo = output_spec .halo
385406 if halo is None :
386407 raise ValueError ("Model does not provide a valid halo to use for tiling with default parameters" )
408+
387409 tiling = {
388410 "halo" : {ax : ha for ax , ha in zip (axes , halo ) if ax in "xyz" },
389411 "tile" : {ax : sh for ax , sh in zip (axes , shape ) if ax in "xyz" },
390412 }
413+ check_tiling (tiling )
391414 else :
392415 tiling = None
393416 else :
0 commit comments