Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 63 additions & 10 deletions src/torchio/transforms/preprocessing/spatial/crop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import copy

import nibabel as nib
import numpy as np

from ....data.image import Image
from ....data.subject import Subject
from .bounds_transform import BoundsTransform
from .bounds_transform import TypeBounds
Expand All @@ -24,33 +27,83 @@ class Crop(BoundsTransform):
If only one value :math:`n` is provided, then
:math:`w_{ini} = w_{fin} = h_{ini} = h_{fin}
= d_{ini} = d_{fin} = n`.
copy: bool, optional
This transform overrides the ``copy`` argument of the base transform and
copies only the cropped patch, instead of the whole image.
If ``True``, the cropped patch will be copied into a new subject.
If ``False``, the image will be cropped in place. Default: ``True``.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.

.. seealso:: If you want to pass the output shape instead, please use
:class:`~torchio.transforms.CropOrPad`.
"""

def __init__(self, cropping: TypeBounds, **kwargs):
super().__init__(cropping, **kwargs)
def __init__(self, cropping: TypeBounds, copy=True, **kwargs):
self.copy_patch = copy
super().__init__(cropping, copy=False, **kwargs)
self.cropping = cropping
self.args_names = ['cropping']

def apply_transform(self, sample) -> Subject:
def apply_transform(self, sample: Subject) -> Subject:
assert self.bounds_parameters is not None
low = self.bounds_parameters[::2]
high = self.bounds_parameters[1::2]
index_ini = low
index_fin = np.array(sample.spatial_shape) - high
for image in self.get_images(sample):
new_origin = nib.affines.apply_affine(image.affine, index_ini)
new_affine = image.affine.copy()
new_affine[:3, 3] = new_origin
i0, j0, k0 = index_ini
i1, j1, k1 = index_fin

if self.copy_patch:
# Create a new subject with only the cropped patch
sample_attributes = {}

# Copy all non-image attributes
for key, value in sample.items():
if (
key
not in sample.get_images_dict(
intensity_only=False, include=self.include, exclude=self.exclude
).keys()
):
sample_attributes[key] = copy.deepcopy(value)
else:
sample_attributes[key] = self.crop_image(
value, index_ini, index_fin
)
cropped_sample = type(sample)(**sample_attributes)

# Copy applied transforms history
cropped_sample.applied_transforms = copy.deepcopy(sample.applied_transforms)

cropped_sample.update_attributes()
return cropped_sample
else:
# Crop in place
for image in self.get_images(sample):
self.crop_image(image, index_ini, index_fin)
return sample

def crop_image(self, image: Image, index_ini: tuple, index_fin: tuple) -> None:
new_origin = nib.affines.apply_affine(image.affine, index_ini)
new_affine = image.affine.copy()
new_affine[:3, 3] = new_origin
i0, j0, k0 = index_ini
i1, j1, k1 = index_fin

# Crop the image data
if self.copy_patch:
# Create a new image with the cropped data
cropped_data = image.data[:, i0:i1, j0:j1, k0:k1].clone()
new_image = type(image)(
tensor=cropped_data,
affine=new_affine,
type=image.type,
path=image.path,
)
return new_image
else:
image.set_data(image.data[:, i0:i1, j0:j1, k0:k1].clone())
image.affine = new_affine
return sample
return image

def inverse(self):
from .pad import Pad
Expand Down
2 changes: 1 addition & 1 deletion src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Transform(ABC):
Args:
p: Probability that this transform will be applied.
copy: Make a shallow copy of the input before applying the transform.
copy: Make a deep copy of the input before applying the transform.
include: Sequence of strings with the names of the only images to which
the transform will be applied.
Mandatory if the input is a :class:`dict`.
Expand Down
14 changes: 14 additions & 0 deletions tests/transforms/preprocessing/test_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,17 @@ def test_tensor_single_channel(self):
def test_tensor_multi_channel(self):
crop = tio.Crop(1)
assert crop(torch.rand(3, 10, 10, 10)).shape == (3, 8, 8, 8)

def test_subject_copy(self):
crop = tio.Crop(1, copy=True)
subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))
cropped_subject = crop(subject)
assert cropped_subject.t1.shape == (1, 8, 8, 8)
assert subject.t1.shape == (1, 10, 10, 10)

def test_subject_no_copy(self):
crop = tio.Crop(1, copy=False)
subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))
cropped_subject = crop(subject)
assert cropped_subject.t1.shape == (1, 8, 8, 8)
assert subject.t1.shape == (1, 8, 8, 8)
Loading