
Commit 6e77940

Add support for outputs smaller than inputs in aggregator (#1394)
1 parent 2785cd3 commit 6e77940

4 files changed: +63 −11 lines


src/torchio/data/inference/aggregator.py

Lines changed: 21 additions & 6 deletions
```diff
@@ -24,6 +24,12 @@ class GridAggregator:
             in the overlapping areas will be weighted with a Hann window
             function. See the `grid aggregator tests`_ for a raw visualization
             of the three modes.
+        downsampling_factor: Factor by which the output volume is expected to
+            be smaller than the input volume in each spatial dimension. This is
+            useful when the model downsamples the input (e.g., with strided
+            convolutions or pooling layers). Currently, only a single integer
+            is supported, which applies the same downsampling factor to all
+            spatial dimensions.
 
     .. _grid aggregator tests: https://github.com/TorchIO-project/torchio/blob/main/tests/data/inference/test_aggregator.py
 
@@ -32,7 +38,12 @@ class GridAggregator:
         information about patch-based sampling.
     """
 
-    def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
+    def __init__(
+        self,
+        sampler: GridSampler,
+        overlap_mode: str = 'crop',
+        downsampling_factor: int = 1,  # TODO: support one per dimension
+    ):
         subject = sampler.subject
         self.volume_padded = sampler.padding_mode is not None
         self.spatial_shape = subject.spatial_shape
@@ -43,6 +54,9 @@ def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
         self.overlap_mode = overlap_mode
         self._avgmask_tensor: torch.Tensor | None = None
         self._hann_window: torch.Tensor | None = None
+        self._downsampling_factor = downsampling_factor
+        shape_array = np.array(subject.spatial_shape) // self._downsampling_factor
+        self.spatial_shape = tuple(shape_array.tolist())
 
     @staticmethod
     def _parse_overlap_mode(overlap_mode):
@@ -137,7 +151,7 @@ def add_batch(
         batch_tensor: torch.Tensor,
         locations: torch.Tensor,
     ) -> None:
-        """Add batch processed by a CNN to the output prediction volume.
+        """Add batch processed by a network to the output prediction volume.
 
         Args:
             batch_tensor: 5D tensor, typically the output of a convolutional
@@ -147,12 +161,13 @@ def add_batch(
                 extracted using ``batch[torchio.LOCATION]``.
         """
         batch = batch_tensor.cpu()
-        locations_array = locations.cpu().numpy()
-        patch_sizes = locations_array[:, 3:] - locations_array[:, :3]
+        locations_array = locations.cpu().numpy() // self._downsampling_factor
+        target_shapes = locations_array[:, 3:] - locations_array[:, :3]
         # There should be only one patch size
-        assert len(np.unique(patch_sizes, axis=0)) == 1
+        assert len(np.unique(target_shapes, axis=0)) == 1
         input_spatial_shape = tuple(batch.shape[-3:])
-        target_spatial_shape = tuple(patch_sizes[0])
+        target_spatial_shape_array = target_shapes[0]
+        target_spatial_shape = tuple(target_spatial_shape_array.tolist())
         if input_spatial_shape != target_spatial_shape:
             message = (
                 f'The shape of the input batch, {input_spatial_shape},'
```
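
For context, a minimal sketch of how the new argument slots into the usual patch-based inference loop. The stride-2 `Conv3d` stand-in model and all sizes are illustrative assumptions; the TorchIO calls mirror the test added at the bottom of this commit.

```python
import torch
import torchio as tio

# Stand-in for a network that halves each spatial dimension (an assumption
# for illustration; any model with a fixed integer stride works the same way).
model = torch.nn.Conv3d(in_channels=1, out_channels=8, kernel_size=2, stride=2)

subject = tio.Subject(img=tio.ScalarImage(tensor=torch.rand(1, 64, 64, 64)))
sampler = tio.data.GridSampler(subject, patch_size=32)
# Declare that the model output is half the size of its input
aggregator = tio.data.GridAggregator(sampler, downsampling_factor=2)
loader = tio.SubjectsLoader(sampler, batch_size=4)

with torch.no_grad():
    for batch in loader:
        outputs = model(batch['img'][tio.DATA])  # (B, 8, 16, 16, 16)
        aggregator.add_batch(outputs, batch[tio.LOCATION])

output = aggregator.get_output_tensor()  # torch.Size([8, 32, 32, 32])
```

Because the aggregator divides each patch location by the factor, patches that tile the input also tile the smaller output volume, so the loop itself needs no extra bookkeeping.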

src/torchio/datasets/ct_rate.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -319,14 +319,14 @@ def _instantiate_image(self, image_row: pd.Series) -> ScalarImage:
             image_row: A pandas Series representing a row from the metadata DataFrame,
                 containing information about a single image.
         """
-        image_dict = image_row.to_dict()
-        filename = image_dict[self._FILENAME_KEY]
+        image_dict: dict[str, str | dict[str, str]] = image_row.to_dict()  # type: ignore[assignment]
+        filename: str = image_dict[self._FILENAME_KEY]  # type: ignore[assignment]
         relative_image_path = self._get_image_path(
             filename,
             load_fixed=self._load_fixed,
         )
         image_path = self._root_dir / relative_image_path
-        report_dict = self._extract_report_dict(image_dict)
+        report_dict = self._extract_report_dict(image_dict)  # type: ignore[arg-type]
         image_dict[self._report_key] = report_dict
         image = ScalarImage(image_path, verify_path=self._verify_paths, **image_dict)
         return image
```
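
The changes in this file only affect static type checking: `Series.to_dict()` is presumably typed more loosely by the pandas stubs than the narrowed annotation used here, so the assignments and the `_extract_report_dict` call need explicit `# type: ignore` comments. Runtime behavior is unchanged.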

src/torchio/datasets/ixi.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -107,7 +107,7 @@ def _check_exists(root, modalities):
         return exists
 
     @staticmethod
-    def _get_subjects_list(root, modalities):
+    def _get_subjects_list(root: Path, modalities: Sequence[str]) -> list[Subject]:
         # The number of files for each modality is not the same
         # E.g. 581 for T1, 578 for T2
         # Let's just use the first modality as reference for now
@@ -134,7 +134,7 @@ def _get_subjects_list(root, modalities):
             skip_subject = False
             if skip_subject:
                 continue
-            subjects.append(Subject(**images_dict))
+            subjects.append(Subject(**images_dict))  # type: ignore[arg-type]
         return subjects
 
     def _download(self, root, modalities):
```

tests/data/inference/test_aggregator.py

Lines changed: 37 additions & 0 deletions
```diff
@@ -142,3 +142,40 @@ def test_bad_aggregator_shape(self):
         inference_batch = torch.stack(patches)
         with pytest.raises(RuntimeError):
             aggregator.add_batch(inference_batch, batch[tio.LOCATION])
+
+    def test_downsampling_model(self):
+        # This might be useful to compute image embeddings using a sliding window
+        downsampling_factor = 4  # e.g. patch size in a ViT
+        embedding_dim = 5
+        net_input_size = 20
+        image_size = 40
+
+        def network(x):
+            down = x[
+                ...,
+                ::downsampling_factor,
+                ::downsampling_factor,
+                ::downsampling_factor,
+            ]
+            embeddings = torch.cat(embedding_dim * [down], dim=1)
+            return embeddings
+
+        tensor = torch.ones(1, image_size, image_size, image_size)
+        image_name = 'img'
+        subject = tio.Subject({image_name: tio.ScalarImage(tensor=tensor)})
+        sampler = tio.data.GridSampler(
+            subject,
+            patch_size=net_input_size,
+        )
+        aggregator = tio.data.GridAggregator(
+            sampler,
+            downsampling_factor=downsampling_factor,
+        )
+        loader = tio.SubjectsLoader(sampler, batch_size=3)
+        for batch in loader:
+            input_batch = batch[image_name][tio.DATA]
+            embeddings = network(input_batch)
+            aggregator.add_batch(embeddings, batch[tio.LOCATION])
+        output = aggregator.get_output_tensor()
+        expected_shape = (embedding_dim,) + (image_size // downsampling_factor,) * 3
+        self.assertEqual(output.shape, expected_shape)
```
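
As a sanity check on `expected_shape`: each 20³ input patch is strided down to a 5³ grid, concatenating the single-channel result with itself five times yields `embedding_dim = 5` channels, and dividing the patch locations by the factor makes the eight patches tile a (40 // 4)³ = 10³ volume, so the aggregated output has shape (5, 10, 10, 10).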
