Commit 5f28b56

Authored by StijnvWijn (Stijn van Wijngaarden) and a co-author
Stop copying whole image before cropping (#1308)
Co-authored-by: Stijn van Wijngaarden <[email protected]>
1 parent 1471062 commit 5f28b56

File tree

3 files changed: +104 -13 lines changed

src/torchio/transforms/preprocessing/spatial/crop.py

Lines changed: 83 additions & 12 deletions

@@ -1,6 +1,9 @@
-import nibabel as nib
+from copy import deepcopy
+
 import numpy as np
+from nibabel.affines import apply_affine
 
+from ....data.image import Image
 from ....data.subject import Subject
 from .bounds_transform import BoundsTransform
 from .bounds_transform import TypeBounds
@@ -24,33 +27,101 @@ class Crop(BoundsTransform):
             If only one value :math:`n` is provided, then
             :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin}
             = d_{ini} = d_{fin} = n`.
+        copy: If ``True``, each image will be cropped and the patch copied to
+            a new subject. If ``False``, each image will be cropped in place.
+            This transform overrides the ``copy`` argument of the base
+            transform and copies only the cropped patch instead of the whole
+            image. This can provide a significant speedup when cropping small
+            patches from large images.
         **kwargs: See :class:`~torchio.transforms.Transform` for additional
             keyword arguments.
 
     .. seealso:: If you want to pass the output shape instead, please use
         :class:`~torchio.transforms.CropOrPad` instead.
     """
 
-    def __init__(self, cropping: TypeBounds, **kwargs):
-        super().__init__(cropping, **kwargs)
+    def __init__(self, cropping: TypeBounds, copy=True, **kwargs):
+        self._copy_patch = copy
+        # The Transform base class deep-copies the whole subject by default.
+        # We want to copy only the cropped patch, so we override that behavior.
+        super().__init__(cropping, copy=False, **kwargs)
         self.cropping = cropping
         self.args_names = ['cropping']
 
-    def apply_transform(self, sample) -> Subject:
+    def apply_transform(self, subject: Subject) -> Subject:
         assert self.bounds_parameters is not None
         low = self.bounds_parameters[::2]
         high = self.bounds_parameters[1::2]
         index_ini = low
-        index_fin = np.array(sample.spatial_shape) - high
-        for image in self.get_images(sample):
-            new_origin = nib.affines.apply_affine(image.affine, index_ini)
-            new_affine = image.affine.copy()
-            new_affine[:3, 3] = new_origin
-            i0, j0, k0 = index_ini
-            i1, j1, k1 = index_fin
+        index_fin = np.array(subject.spatial_shape) - high
+
+        if self._copy_patch:
+            # Create a clean new subject to copy the images into. We use
+            # __new__ to avoid calling __init__, so we don't have to specify
+            # the images immediately.
+            cropped_subject = subject.__class__.__new__(subject.__class__)
+            image_keys_to_crop = subject.get_images_dict(
+                intensity_only=False,
+                include=self.include,
+                exclude=self.exclude,
+            ).keys()
+            keys_to_expose = subject.keys()
+            # Copy all attributes we don't want to crop.
+            # __dict__ returns all attributes, not just the images.
+            for key, value in subject.__dict__.items():
+                if key not in image_keys_to_crop:
+                    copied_value = deepcopy(value)
+                    # Setting __dict__ does not allow key-indexing the
+                    # attribute, so we set it explicitly if we want to
+                    # expose it.
+                    if key in keys_to_expose:
+                        cropped_subject[key] = copied_value
+                    cropped_subject.__dict__[str(key)] = copied_value
+                else:
+                    # Images are always exposed, so we don't need to set __dict__
+                    cropped_subject[key] = self._crop_image(
+                        value,
+                        index_ini,
+                        index_fin,
+                        copy_patch=self._copy_patch,
+                    )
+
+            # Update the __dict__ attribute to include the cropped images
+            cropped_subject.update_attributes()
+            return cropped_subject
+        else:
+            # Crop in place
+            for image in self.get_images(subject):
+                self._crop_image(
+                    image,
+                    index_ini,
+                    index_fin,
+                    copy_patch=self._copy_patch,
+                )
+            return subject
+
+    @staticmethod
+    def _crop_image(
+        image: Image, index_ini: tuple, index_fin: tuple, *, copy_patch: bool
+    ) -> Image:
+        new_origin = apply_affine(image.affine, index_ini)
+        new_affine = image.affine.copy()
+        new_affine[:3, 3] = new_origin
+        i0, j0, k0 = index_ini
+        i1, j1, k1 = index_fin
+
+        # Crop the image data
+        if copy_patch:
+            # Create a new image with the cropped data
+            cropped_data = image.data[:, i0:i1, j0:j1, k0:k1].clone()
+            new_image = type(image)(
+                tensor=cropped_data,
+                affine=new_affine,
+                type=image.type,
+                path=image.path,
+            )
+            return new_image
+        else:
             image.set_data(image.data[:, i0:i1, j0:j1, k0:k1].clone())
             image.affine = new_affine
-        return sample
+            return image
 
     def inverse(self):
         from .pad import Pad
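
The effect of the new copy argument is easiest to see from the caller's
side. The following usage sketch is not part of the commit; it assumes only
the public torchio API, with shapes chosen to mirror the tests added below:

    import torch
    import torchio as tio

    subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 100, 100, 100)))

    # copy=True (default): the input is untouched and only the cropped
    # patch is copied into the returned subject.
    patch = tio.Crop(40)(subject)  # remove 40 voxels from every side
    print(patch.t1.shape)          # (1, 20, 20, 20)
    print(subject.t1.shape)        # (1, 100, 100, 100)

    # copy=False: the images are cropped in place.
    tio.Crop(40, copy=False)(subject)
    print(subject.t1.shape)        # (1, 20, 20, 20)

Because only the cropped patch is copied instead of the full volume, the
copy=True path now scales with the patch size rather than the image size.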

src/torchio/transforms/transform.py

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ class Transform(ABC):
 
     Args:
         p: Probability that this transform will be applied.
-        copy: Make a shallow copy of the input before applying the transform.
+        copy: Make a deep copy of the input before applying the transform.
         include: Sequence of strings with the names of the only images to which
             the transform will be applied.
             Mandatory if the input is a :class:`dict`.
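
The one-word fix above is not cosmetic: the base Transform deep-copies the
input by default, so input and output share no data. A minimal sketch of
that behavior (not part of the commit), using tio.Pad only because it keeps
the default copy=True and shares Crop's bounds-style signature:

    import torch
    import torchio as tio

    subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))

    padded = tio.Pad(1)(subject)    # deep-copies the subject, then pads
    print(padded.t1.shape)          # (1, 12, 12, 12)
    print(subject.t1.shape)         # (1, 10, 10, 10): input unchanged
    print(padded.t1 is subject.t1)  # False: nothing is shared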

tests/transforms/preprocessing/test_crop.py

Lines changed: 20 additions & 0 deletions

@@ -13,3 +13,23 @@ def test_tensor_single_channel(self):
     def test_tensor_multi_channel(self):
         crop = tio.Crop(1)
         assert crop(torch.rand(3, 10, 10, 10)).shape == (3, 8, 8, 8)
+
+    def test_subject_copy(self):
+        crop = tio.Crop(1, copy=True)
+        subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))
+        cropped_subject = crop(subject)
+        assert cropped_subject.t1.shape == (1, 8, 8, 8)
+        assert subject.t1.shape == (1, 10, 10, 10)
+
+        cropped2_subject = crop(cropped_subject)
+        assert cropped2_subject.t1.shape == (1, 6, 6, 6)
+        assert cropped_subject.t1.shape == (1, 8, 8, 8)
+        assert len(cropped2_subject.applied_transforms) == 2
+        assert len(cropped_subject.applied_transforms) == 1
+
+    def test_subject_no_copy(self):
+        crop = tio.Crop(1, copy=False)
+        subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))
+        cropped_subject = crop(subject)
+        assert cropped_subject.t1.shape == (1, 8, 8, 8)
+        assert subject.t1.shape == (1, 8, 8, 8)
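
The tests check correctness but not the claimed speedup. A rough,
hypothetical micro-benchmark (not part of the commit; absolute timings
depend on hardware) to observe the effect locally:

    import time

    import torch
    import torchio as tio

    subject = tio.Subject(big=tio.ScalarImage(tensor=torch.rand(1, 256, 256, 256)))
    crop = tio.Crop(120)  # keep only a 16x16x16 patch

    start = time.perf_counter()
    for _ in range(10):
        crop(subject)
    print(f'10 crops: {time.perf_counter() - start:.3f} s')
    # Before this change, every call deep-copied the full 256^3 subject;
    # now only the 16^3 patch is copied, so the loop should run much faster.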
