89 changes: 80 additions & 9 deletions src/torchio/transforms/preprocessing/spatial/crop.py
@@ -1,6 +1,9 @@
from copy import deepcopy

import numpy as np
from nibabel.affines import apply_affine

from ....data.image import Image
from ....data.subject import Subject
from .bounds_transform import BoundsTransform
from .bounds_transform import TypeBounds
@@ -24,15 +27,23 @@ class Crop(BoundsTransform):
If only one value :math:`n` is provided, then
:math:`w_{ini} = w_{fin} = h_{ini} = h_{fin}
= d_{ini} = d_{fin} = n`.
copy: If ``True``, each image is cropped and the patch is copied to a new
subject. If ``False``, each image is cropped in place. This transform
overrides the ``copy`` argument of the base transform so that only the
cropped patch is copied instead of the whole image, which can provide a
significant speedup when cropping small patches from large images.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.

.. seealso:: If you want to pass the output shape instead, please use
:class:`~torchio.transforms.CropOrPad`.
"""

def __init__(self, cropping: TypeBounds, **kwargs):
super().__init__(cropping, **kwargs)
def __init__(self, cropping: TypeBounds, copy=True, **kwargs):
self._copy_patch = copy
# The Transform base class deepcopies the whole subject by default.
# We want to copy only the cropped patch, so we override that behavior
super().__init__(cropping, copy=False, **kwargs)
self.cropping = cropping
self.args_names = ['cropping']

@@ -42,15 +53,75 @@ def apply_transform(self, subject: Subject) -> Subject:
high = self.bounds_parameters[1::2]
index_ini = low
index_fin = np.array(subject.spatial_shape) - high
for image in self.get_images(subject):
new_origin = apply_affine(image.affine, index_ini)
new_affine = image.affine.copy()
new_affine[:3, 3] = new_origin
i0, j0, k0 = index_ini
i1, j1, k1 = index_fin

if self._copy_patch:
# Create a clean new subject to copy the images into
# We use __new__ to avoid calling __init__, so we don't have to specify images immediately
cropped_subject = subject.__class__.__new__(subject.__class__)
image_keys_to_crop = subject.get_images_dict(
intensity_only=False,
include=self.include,
exclude=self.exclude,
).keys()
keys_to_expose = subject.keys()
# Copy all attributes we don't want to crop
# __dict__ contains all attributes, not just the images
for key, value in subject.__dict__.items():
if key not in image_keys_to_crop:
copied_value = deepcopy(value)
# Setting __dict__ alone does not make the attribute key-indexable,
# so we also set the item explicitly if we want to expose it
if key in keys_to_expose:
cropped_subject[key] = copied_value
cropped_subject.__dict__[str(key)] = copied_value
else:
# Images are always exposed, so we don't worry about setting __dict__
cropped_subject[key] = self._crop_image(
value,
index_ini,
index_fin,
copy_patch=self._copy_patch,
)

# Update the __dict__ attribute to include the cropped images
cropped_subject.update_attributes()
return cropped_subject
else:
# Crop in place
for image in self.get_images(subject):
self._crop_image(
image,
index_ini,
index_fin,
copy_patch=self._copy_patch,
)
return subject

@staticmethod
def _crop_image(
image: Image, index_ini: tuple, index_fin: tuple, *, copy_patch: bool
) -> Image:
new_origin = apply_affine(image.affine, index_ini)
new_affine = image.affine.copy()
new_affine[:3, 3] = new_origin
i0, j0, k0 = index_ini
i1, j1, k1 = index_fin

# Crop the image data
if copy_patch:
# Create a new image with the cropped data
cropped_data = image.data[:, i0:i1, j0:j1, k0:k1].clone()
new_image = type(image)(
tensor=cropped_data,
affine=new_affine,
type=image.type,
path=image.path,
)
return new_image
else:
image.set_data(image.data[:, i0:i1, j0:j1, k0:k1].clone())
image.affine = new_affine
return subject
return image

def inverse(self):
from .pad import Pad
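For reference, the origin update in `_crop_image` follows from how a voxel-to-world affine works: after discarding `index_ini` voxels from the low end of each axis, the first kept voxel becomes the new origin, and its world coordinates are exactly `apply_affine(image.affine, index_ini)`. A minimal numeric sketch of that bookkeeping (the 2 mm isotropic affine below is an illustrative assumption, not taken from this PR):

```python
import numpy as np
from nibabel.affines import apply_affine

# Hypothetical voxel-to-world affine: 2 mm spacing, origin at (-10, -10, -10)
affine = np.array([
    [2.0, 0.0, 0.0, -10.0],
    [0.0, 2.0, 0.0, -10.0],
    [0.0, 0.0, 2.0, -10.0],
    [0.0, 0.0, 0.0, 1.0],
])
index_ini = (1, 1, 1)  # crop one voxel from the low end of each axis

# World coordinates of the first kept voxel become the new origin
new_origin = apply_affine(affine, index_ini)
new_affine = affine.copy()
new_affine[:3, 3] = new_origin

print(new_origin)  # [-8. -8. -8.]: shifted by one 2 mm voxel per axis
```

Only the translation column of the affine changes; the rotation/scaling block is untouched, which is why the crop can be undone by padding (see `inverse`).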
20 changes: 20 additions & 0 deletions tests/transforms/preprocessing/test_crop.py
@@ -13,3 +13,23 @@ def test_tensor_single_channel(self):
def test_tensor_multi_channel(self):
crop = tio.Crop(1)
assert crop(torch.rand(3, 10, 10, 10)).shape == (3, 8, 8, 8)

def test_subject_copy(self):
crop = tio.Crop(1, copy=True)
subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))
cropped_subject = crop(subject)
assert cropped_subject.t1.shape == (1, 8, 8, 8)
assert subject.t1.shape == (1, 10, 10, 10)

cropped2_subject = crop(cropped_subject)
assert cropped2_subject.t1.shape == (1, 6, 6, 6)
assert cropped_subject.t1.shape == (1, 8, 8, 8)
assert len(cropped2_subject.applied_transforms) == 2
assert len(cropped_subject.applied_transforms) == 1

def test_subject_no_copy(self):
crop = tio.Crop(1, copy=False)
subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))
cropped_subject = crop(subject)
assert cropped_subject.t1.shape == (1, 8, 8, 8)
assert subject.t1.shape == (1, 8, 8, 8)
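The speedup claimed in the `copy` docstring comes from deep-copying a small patch instead of the whole subject. A rough micro-benchmark sketch against this PR's API (the volume size, crop amount, and timing harness are illustrative assumptions, not from the PR):

```python
import time
from copy import deepcopy

import torch
import torchio as tio

big = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 256, 256, 256)))

# New behavior: only the 16^3 patch is copied; `big` is left untouched
crop_patch = tio.Crop(120, copy=True)
start = time.perf_counter()
patch = crop_patch(big)
t_patch = time.perf_counter() - start

# Old default, emulated: deepcopy the whole subject, then crop in place
crop_inplace = tio.Crop(120, copy=False)
start = time.perf_counter()
whole = crop_inplace(deepcopy(big))
t_whole = time.perf_counter() - start

assert patch.t1.shape == whole.t1.shape == (1, 16, 16, 16)
print(f'patch copy: {t_patch:.4f} s, whole-subject copy: {t_whole:.4f} s')
```

With `copy=False` and no manual deepcopy, the input subject itself is modified, as the `test_subject_no_copy` test above verifies.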