@@ -24,6 +24,12 @@ class GridAggregator:
2424 in the overlapping areas will be weighted with a Hann window
2525 function. See the `grid aggregator tests`_ for a raw visualization
2626 of the three modes.
27+ downsampling_factor: Factor by which the output volume is expected to
28+ be smaller than the input volume in each spatial dimension. This is
29+ useful when the model downsamples the input (e.g., with strided
30+ convolutions or pooling layers). Currently, only a single integer
31+ is supported, which applies the same downsampling factor to all
32+ spatial dimensions.
2733
2834 .. _grid aggregator tests: https://github.com/TorchIO-project/torchio/blob/main/tests/data/inference/test_aggregator.py
2935
@@ -32,7 +38,12 @@ class GridAggregator:
3238 information about patch-based sampling.
3339 """
3440
35- def __init__ (self , sampler : GridSampler , overlap_mode : str = 'crop' ):
41+ def __init__ (
42+ self ,
43+ sampler : GridSampler ,
44+ overlap_mode : str = 'crop' ,
45+ downsampling_factor : int = 1 , # TODO: support one per dimension
46+ ):
3647 subject = sampler .subject
3748 self .volume_padded = sampler .padding_mode is not None
3849 self .spatial_shape = subject .spatial_shape
@@ -43,6 +54,9 @@ def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
4354 self .overlap_mode = overlap_mode
4455 self ._avgmask_tensor : torch .Tensor | None = None
4556 self ._hann_window : torch .Tensor | None = None
57+ self ._downsampling_factor = downsampling_factor
58+ shape_array = np .array (subject .spatial_shape ) // self ._downsampling_factor
59+ self .spatial_shape = tuple (shape_array .tolist ())
4660
4761 @staticmethod
4862 def _parse_overlap_mode (overlap_mode ):
@@ -137,7 +151,7 @@ def add_batch(
137151 batch_tensor : torch .Tensor ,
138152 locations : torch .Tensor ,
139153 ) -> None :
140- """Add batch processed by a CNN to the output prediction volume.
154+ """Add batch processed by a network to the output prediction volume.
141155
142156 Args:
143157 batch_tensor: 5D tensor, typically the output of a convolutional
@@ -147,12 +161,13 @@ def add_batch(
147161 extracted using ``batch[torchio.LOCATION]``.
148162 """
149163 batch = batch_tensor .cpu ()
150- locations_array = locations .cpu ().numpy ()
151- patch_sizes = locations_array [:, 3 :] - locations_array [:, :3 ]
164+ locations_array = locations .cpu ().numpy () // self . _downsampling_factor
165+ target_shapes = locations_array [:, 3 :] - locations_array [:, :3 ]
152166 # There should be only one patch size
153- assert len (np .unique (patch_sizes , axis = 0 )) == 1
167+ assert len (np .unique (target_shapes , axis = 0 )) == 1
154168 input_spatial_shape = tuple (batch .shape [- 3 :])
155- target_spatial_shape = tuple (patch_sizes [0 ])
169+ target_spatial_shape_array = target_shapes [0 ]
170+ target_spatial_shape = tuple (target_spatial_shape_array .tolist ())
156171 if input_spatial_shape != target_spatial_shape :
157172 message = (
158173 f'The shape of the input batch, { input_spatial_shape } ,'
0 commit comments