@@ -414,8 +414,8 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
         return {"tile": tile, "halo": halo}

     if torch.cuda.is_available():
-        # We always use the same default halo.
-        halo = {"x": 64, "y": 64, "z": 16}  # before 64,64,8
+        # The default halo size.
+        halo = {"x": 64, "y": 64, "z": 16}

         # Determine the GPU RAM and derive a suitable tiling.
         vram = torch.cuda.get_device_properties(0).total_memory / 1e9
@@ -426,9 +426,11 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
             tile = {"x": 512, "y": 512, "z": 64}
         elif vram >= 20:
             tile = {"x": 352, "y": 352, "z": 48}
+        elif vram >= 10:
+            tile = {"x": 256, "y": 256, "z": 32}
+            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
         else:
-            # TODO determine tilings for smaller VRAM
-            raise NotImplementedError(f"Estimating the tile size for a GPU with {vram} GB is not yet supported.")
+            raise NotImplementedError(f"Inference on a GPU with {vram} GB VRAM is not supported.")

         print(f"Determined tile size: {tile}")
         tiling = {"tile": tile, "halo": halo}
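
A minimal sketch of how the returned tiling could be consumed for blockwise inference. The tile and halo values match the new >= 10 GB bucket added above; the example volume shape and the block arithmetic are illustrative assumptions, not part of this commit.

# Minimal sketch (not from this commit): each GPU pass covers
# tile + 2 * halo voxels per axis, and the halo is cropped after
# prediction so only the inner tile is written back.
tile = {"x": 256, "y": 256, "z": 32}  # the new >= 10 GB bucket
halo = {"x": 64, "y": 64, "z": 8}

volume_shape = {"z": 128, "y": 1024, "x": 1024}  # assumed example volume

block = {ax: tile[ax] + 2 * halo[ax] for ax in ("z", "y", "x")}
print("block shape per pass:", block)  # {'z': 48, 'y': 384, 'x': 384}

# Ceiling division: number of tiles needed to cover each axis.
n_tiles = {ax: -(-volume_shape[ax] // tile[ax]) for ax in ("z", "y", "x")}
print("tile grid:", n_tiles)  # {'z': 4, 'y': 4, 'x': 4}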