
Commit 6eedfe1

Add support for downsampled outputs in aggregator

1 parent 2785cd3 commit 6eedfe1

File tree: 2 files changed, +58 −6 lines changed

  src/torchio/data/inference/aggregator.py
  tests/data/inference/test_aggregator.py

src/torchio/data/inference/aggregator.py

Lines changed: 21 additions & 6 deletions
@@ -24,6 +24,12 @@ class GridAggregator:
             in the overlapping areas will be weighted with a Hann window
             function. See the `grid aggregator tests`_ for a raw visualization
             of the three modes.
+        downsampling_factor: Factor by which the output volume is expected to
+            be smaller than the input volume in each spatial dimension. This is
+            useful when the model downsamples the input (e.g., with strided
+            convolutions or pooling layers). Currently, only a single integer
+            is supported, which applies the same downsampling factor to all
+            spatial dimensions.
 
     .. _grid aggregator tests: https://github.com/TorchIO-project/torchio/blob/main/tests/data/inference/test_aggregator.py
 
@@ -32,7 +38,12 @@ class GridAggregator:
     information about patch-based sampling.
     """
 
-    def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
+    def __init__(
+        self,
+        sampler: GridSampler,
+        overlap_mode: str = 'crop',
+        downsampling_factor: int = 1,  # TODO: support one per dimension
+    ):
         subject = sampler.subject
         self.volume_padded = sampler.padding_mode is not None
         self.spatial_shape = subject.spatial_shape
@@ -43,6 +54,9 @@ def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
         self.overlap_mode = overlap_mode
         self._avgmask_tensor: torch.Tensor | None = None
         self._hann_window: torch.Tensor | None = None
+        self._downsampling_factor = downsampling_factor
+        shape_array = np.array(subject.spatial_shape) // self._downsampling_factor
+        self.spatial_shape = tuple(shape_array.tolist())
 
     @staticmethod
     def _parse_overlap_mode(overlap_mode):
@@ -137,7 +151,7 @@ def add_batch(
         batch_tensor: torch.Tensor,
         locations: torch.Tensor,
     ) -> None:
-        """Add batch processed by a CNN to the output prediction volume.
+        """Add batch processed by a network to the output prediction volume.
 
         Args:
             batch_tensor: 5D tensor, typically the output of a convolutional
@@ -147,12 +161,13 @@ def add_batch(
                 extracted using ``batch[torchio.LOCATION]``.
         """
         batch = batch_tensor.cpu()
-        locations_array = locations.cpu().numpy()
-        patch_sizes = locations_array[:, 3:] - locations_array[:, :3]
+        locations_array = locations.cpu().numpy() // self._downsampling_factor
+        target_shapes = locations_array[:, 3:] - locations_array[:, :3]
         # There should be only one patch size
-        assert len(np.unique(patch_sizes, axis=0)) == 1
+        assert len(np.unique(target_shapes, axis=0)) == 1
         input_spatial_shape = tuple(batch.shape[-3:])
-        target_spatial_shape = tuple(patch_sizes[0])
+        target_spatial_shape_array = target_shapes[0]
+        target_spatial_shape = tuple(target_spatial_shape_array.tolist())
         if input_spatial_shape != target_spatial_shape:
             message = (
                 f'The shape of the input batch, {input_spatial_shape},'
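
The core of the change is the location arithmetic in add_batch: patch locations, stored as rows of (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin) indices, are integer-divided by the downsampling factor so that each patch is written into a proportionally smaller output volume. A minimal sketch of that arithmetic (the locations and factor below are illustrative values, not taken from the commit):

    import numpy as np

    # Illustrative values, not from the commit: two patch locations stored as
    # (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin) rows, and a factor of 4.
    downsampling_factor = 4
    locations_array = np.array([
        [0, 0, 0, 20, 20, 20],
        [20, 0, 0, 40, 20, 20],
    ])

    # As in the patched add_batch: dividing the locations by the factor maps
    # each patch into an output grid that is 4x smaller per spatial axis.
    locations_small = locations_array // downsampling_factor
    target_shapes = locations_small[:, 3:] - locations_small[:, :3]
    assert len(np.unique(target_shapes, axis=0)) == 1
    print(locations_small[0])                # [0 0 0 5 5 5]
    print(tuple(target_shapes[0].tolist()))  # (5, 5, 5)

Because the division is integer, the volume shape and the patch locations should be exact multiples of the factor for the scaled patches to tile the output cleanly; the test below picks 40, 20, and 4 accordingly.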

tests/data/inference/test_aggregator.py

Lines changed: 37 additions & 0 deletions
@@ -142,3 +142,40 @@ def test_bad_aggregator_shape(self):
         inference_batch = torch.stack(patches)
         with pytest.raises(RuntimeError):
             aggregator.add_batch(inference_batch, batch[tio.LOCATION])
+
+    def test_downsampling_model(self):
+        # This might be useful to compute image embeddings using a sliding window
+        downsampling_factor = 4  # e.g. patch size in a ViT
+        embedding_dim = 5
+        net_input_size = 20
+        image_size = 40
+
+        def network(x):
+            down = x[
+                ...,
+                ::downsampling_factor,
+                ::downsampling_factor,
+                ::downsampling_factor,
+            ]
+            embeddings = torch.cat(embedding_dim * [down], dim=1)
+            return embeddings
+
+        tensor = torch.ones(1, image_size, image_size, image_size)
+        image_name = 'img'
+        subject = tio.Subject({image_name: tio.ScalarImage(tensor=tensor)})
+        sampler = tio.data.GridSampler(
+            subject,
+            patch_size=net_input_size,
+        )
+        aggregator = tio.data.GridAggregator(
+            sampler,
+            downsampling_factor=downsampling_factor,
+        )
+        loader = tio.SubjectsLoader(sampler, batch_size=3)
+        for batch in loader:
+            input_batch = batch[image_name][tio.DATA]
+            embeddings = network(input_batch)
+            aggregator.add_batch(embeddings, batch[tio.LOCATION])
+        output = aggregator.get_output_tensor()
+        expected_shape = (embedding_dim,) + (image_size // downsampling_factor,) * 3
+        self.assertEqual(output.shape, expected_shape)
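
For context, here is a hypothetical end-to-end usage sketch (not part of the commit) that replaces the test's slicing stand-in with an actual strided convolution; the Conv3d layer and all sizes are assumptions:

    import torch
    import torchio as tio

    # A network that downsamples each spatial dimension by 4 (20^3 patch -> 5^3).
    net = torch.nn.Conv3d(1, 8, kernel_size=4, stride=4)

    subject = tio.Subject(img=tio.ScalarImage(tensor=torch.rand(1, 40, 40, 40)))
    sampler = tio.data.GridSampler(subject, patch_size=20)
    aggregator = tio.data.GridAggregator(sampler, downsampling_factor=4)
    loader = tio.SubjectsLoader(sampler, batch_size=2)

    with torch.no_grad():
        for batch in loader:
            patches = batch['img'][tio.DATA]  # shape: (B, 1, 20, 20, 20)
            aggregator.add_batch(net(patches), batch[tio.LOCATION])

    embeddings = aggregator.get_output_tensor()  # shape: (8, 10, 10, 10)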
