@@ -60,7 +60,7 @@ class ImageSlicer:
6060 Helper class to slice image into tiles and merge them back
6161 """
6262
63- def __init__ (self , image_shape , tile_size , tile_step = 0 , image_margin = 0 , weight = "mean" ):
63+ def __init__ (self , image_shape : Tuple [ int , int ] , tile_size , tile_step = 0 , image_margin = 0 , weight = "mean" ):
6464 """
6565
6666 :param image_shape: Shape of the source image (H, W)
@@ -122,12 +122,6 @@ def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="
122122 else :
123123 margin_left = margin_right = margin_top = margin_bottom = image_margin
124124
125- if (self .image_width + margin_left + margin_right ) % self .tile_size [1 ] != 0 :
126- raise ValueError ()
127-
128- if (self .image_height + margin_top + margin_bottom ) % self .tile_size [0 ] != 0 :
129- raise ValueError ()
130-
131125 self .margin_left = margin_left
132126 self .margin_right = margin_right
133127 self .margin_top = margin_top
@@ -337,6 +331,10 @@ def integrate_batch(self, batch: torch.Tensor, crop_coords):
337331 if batch .device != self .image .device :
338332 batch = batch .to (device = self .image .device )
339333
334+ # Ensure that input batch dtype match the target dtyle of the accumulator
335+ if batch .dtype != self .image .dtype :
336+ batch = batch .type_as (self .image )
337+
340338 for tile , (x , y , tile_width , tile_height ) in zip (batch , crop_coords ):
341339 self .image [:, y : y + tile_height , x : x + tile_width ] += tile * self .weight
342340 self .norm_mask [:, y : y + tile_height , x : x + tile_width ] += self .weight
0 commit comments