Skip to content

Commit 306578e

Browse files
Merge pull request #76 from ChristianHinge/dev/registration
integrate niftyreg aladin affine registration
2 parents 14075b5 + 6470668 commit 306578e

File tree

12 files changed

+500
-80
lines changed

12 files changed

+500
-80
lines changed

poetry.lock

Lines changed: 13 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ click = ">=8.0.1"
2121
torch = "^1.13.1"
2222
torchio = "^0.18.87"
2323
nibabel = "^5.0.1"
24+
niftyreg = "^1.5.70rc1"
2425

2526
[tool.poetry.dev-dependencies]
2627
Pygments = ">=2.10.0"

src/zerodose/__main__.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def main() -> None:
3333
"cuda:7",
3434
]
3535
),
36-
default="cpu",
36+
default="cuda:0",
3737
help="Device to use for inference.",
3838
)
3939

@@ -69,27 +69,46 @@ def main() -> None:
6969
help="Print verbose output.",
7070
)
7171

72+
no_registration_option = click.option(
73+
"-n",
74+
"--no-registration",
75+
"no_registration",
76+
is_flag=True,
77+
default=False,
78+
help="""Skip registration to MNI space.
79+
Useful if the input images already are in MNI space""",
80+
)
81+
7282

7383
@main.command()
7484
@mri_option
7585
@mask_option
7686
@sbpet_output_option
7787
@verbose_option
7888
@device_option
89+
@no_registration_option
7990
def syn(
8091
mri_fnames: Sequence[str],
8192
mask_fnames: Sequence[str],
8293
out_fnames: Union[Sequence[str], None] = None,
83-
verbose: bool = False,
94+
verbose: bool = True,
8495
device: str = "cuda:0",
96+
no_registration: bool = False,
8597
) -> None:
8698
"""Synthesize baseline PET images."""
8799
if out_fnames is None or len(out_fnames) == 0:
88100
out_fnames = [
89101
_create_output_fname(mri_fname, suffix="_sb") for mri_fname in mri_fnames
90102
]
103+
104+
do_registration = not no_registration
91105
synthesize_baselines(
92-
mri_fnames, mask_fnames, out_fnames, verbose=verbose, device=device
106+
mri_fnames,
107+
mask_fnames,
108+
out_fnames,
109+
verbose=verbose,
110+
device=device,
111+
do_registration=do_registration,
93112
)
94113

95114

src/zerodose/dataset.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
"""Dataset class for the ZeroDose project."""
22
from typing import Any
33
from typing import Dict
4+
from typing import List
45

56
import torchio as tio
67

7-
from zerodose.processing import Pad
8+
from zerodose import utils
9+
from zerodose.niftyreg_wrapper import NiftyRegistration
10+
from zerodose.processing import PadAndCropToMNI
811
from zerodose.processing import ToFloat32
912

1013

1114
class SubjectDataset(tio.data.SubjectsDataset):
1215
"""Dataset class for the ZeroDose project."""
1316

14-
def __init__(self, mri_fnames, mask_fnames, out_fnames):
17+
def __init__(self, mri_fnames, mask_fnames, out_fnames, do_registration=True):
1518
"""Initialize the dataset."""
16-
transforms = self._get_augmentation_transform_val()
19+
transforms = self._get_augmentation_transform_val(
20+
do_registration=do_registration
21+
)
1722
subjects = [
1823
self._make_subject_predict(mr_f, ma_f, ou_f)
1924
for mr_f, ma_f, ou_f in zip( # noqa
@@ -41,11 +46,19 @@ def _make_subject_predict(self, mr_path, mask_path, out_fname) -> tio.Subject:
4146

4247
return tio.Subject(subject_dict)
4348

44-
def _get_augmentation_transform_val(self) -> tio.Compose:
45-
return tio.Compose(
49+
def _get_augmentation_transform_val(self, do_registration=True) -> tio.Compose:
50+
augmentations: List[tio.Transform] = []
51+
52+
if do_registration:
53+
ref = utils.get_mni_template()
54+
augmentations.append(NiftyRegistration(floating_image="mr", ref=ref))
55+
56+
augmentations.extend(
4657
[
4758
tio.transforms.ZNormalization(include=["mr"], masking_method="mask"),
48-
Pad(include=["mr", "mask"]),
59+
PadAndCropToMNI(include=["mr", "mask"]),
4960
ToFloat32(include=["mr"]),
5061
]
5162
)
63+
64+
return tio.Compose(augmentations)

src/zerodose/niftyreg_wrapper.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""Niftyreg wrapper for torchio."""
2+
import os
3+
import tempfile
4+
5+
import nibabel as nib
6+
import niftyreg # type: ignore
7+
import numpy as np
8+
import torch
9+
from torchio import Subject
10+
from torchio.data import io
11+
from torchio.transforms import SpatialTransform
12+
13+
14+
def _save_matrix_nifty(affine_mat, file_name):
15+
s = ""
16+
for i in range(affine_mat.shape[0]):
17+
for j in range(affine_mat.shape[1]):
18+
affine_mat[i, j] = affine_mat[i, j].item()
19+
s += str(affine_mat[i, j]) + " "
20+
s = s[:-1]
21+
s += "\n"
22+
with open(file_name, "w") as f:
23+
f.write(s)
24+
25+
26+
def _read_matrix_nifty(file_name):
27+
with open(file_name) as f:
28+
s = f.read()
29+
s = s.split("\n")
30+
s = s[:-1]
31+
s = [i.split(" ") for i in s]
32+
s = [[float(j) for j in i] for i in s]
33+
s = np.array(s)
34+
return s
35+
36+
37+
def _register_mri_to_mni(mri_fname, ref):
38+
39+
temp_aff = tempfile.NamedTemporaryFile(delete=False)
40+
41+
out_mri_fname = os.path.dirname(mri_fname) + "/out_mri1.nii.gz"
42+
mni_template = ref
43+
niftyreg.main(
44+
[
45+
"aladin",
46+
"-flo",
47+
mri_fname,
48+
"-ref",
49+
mni_template,
50+
"-res",
51+
out_mri_fname,
52+
"-aff",
53+
temp_aff.name,
54+
"-speeeeed",
55+
]
56+
)
57+
affine_mat = _read_matrix_nifty(temp_aff.name)
58+
temp_aff.close()
59+
os.remove(temp_aff.name)
60+
return affine_mat
61+
62+
63+
def _nifty_reg_resample(ref_path, flo_img, affine_mat):
64+
65+
temp_flo = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
66+
temp_res = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
67+
temp_aff = tempfile.NamedTemporaryFile(delete=False)
68+
69+
io.write_image(flo_img.tensor, flo_img.affine, temp_flo.name)
70+
_save_matrix_nifty(affine_mat, temp_aff.name)
71+
72+
niftyreg.main(
73+
[
74+
"resample",
75+
"-ref",
76+
ref_path,
77+
"-flo",
78+
temp_flo.name,
79+
"-res",
80+
temp_res.name,
81+
"-aff",
82+
temp_aff.name,
83+
]
84+
)
85+
86+
# After the other process has finished writing to the files, read from them
87+
res = nib.load(temp_res.name)
88+
data = res.get_fdata()
89+
affine = res.affine
90+
91+
temp_flo.close()
92+
temp_res.close()
93+
temp_aff.close()
94+
95+
os.remove(temp_flo.name)
96+
os.remove(temp_res.name)
97+
os.remove(temp_aff.name)
98+
99+
return data, affine
100+
101+
102+
class NiftyRegistration(SpatialTransform):
103+
"""Nifty Registration for torchio."""
104+
105+
def __init__(
106+
self,
107+
floating_image=None,
108+
ref=None,
109+
**kwargs,
110+
):
111+
"""Initialize the niftyreg registration transform."""
112+
self.floating_image = floating_image
113+
super().__init__(**kwargs)
114+
self.ref = ref
115+
116+
def apply_transform(self, subject: Subject) -> Subject:
117+
"""Apply the registration and coregistration."""
118+
temp_flo = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
119+
120+
io.write_image(
121+
subject[self.floating_image].tensor,
122+
subject[self.floating_image].affine,
123+
temp_flo.name,
124+
)
125+
126+
affine = _register_mri_to_mni(temp_flo.name, self.ref)
127+
128+
temp_flo.close()
129+
os.remove(temp_flo.name)
130+
inverse_ref = subject[self.floating_image].path
131+
transformer = NiftyResample(
132+
ref=self.ref, affine=affine, inverse_ref=inverse_ref
133+
)
134+
transformed = transformer(subject)
135+
return transformed # type: ignore
136+
137+
138+
class NiftyResample(SpatialTransform):
139+
"""Nifty resample (coreregistration)."""
140+
141+
def __init__(self, affine, ref=None, is_inverse=False, inverse_ref=None, **kwargs):
142+
"""Initialize the nifty resample."""
143+
self.affine = affine
144+
self.ref = ref
145+
self.is_inverse = is_inverse
146+
self.inverse_ref = inverse_ref
147+
super().__init__(**kwargs)
148+
self.args_names = ("affine", "ref", "is_inverse", "inverse_ref")
149+
150+
def apply_transform(self, subject: Subject) -> Subject:
151+
"""Apply the transform."""
152+
for image in self.get_images(subject):
153+
if self.is_inverse:
154+
if image.path is not None:
155+
ref = image.path
156+
else:
157+
ref = self.inverse_ref
158+
_apply_niftyreg_resample(image, ref, np.linalg.inv(self.affine))
159+
else:
160+
_apply_niftyreg_resample(image, self.ref, self.affine)
161+
return subject
162+
163+
@staticmethod
164+
def is_invertible():
165+
"""Whether the transform is invertible."""
166+
return True
167+
168+
def inverse(self):
169+
"""Return the inverse resample."""
170+
return NiftyResample(
171+
affine=self.affine, is_inverse=True, inverse_ref=self.inverse_ref
172+
)
173+
174+
175+
def _apply_niftyreg_resample(image, ref_path, affine_mat):
176+
data, affine = _nifty_reg_resample(ref_path, image, affine_mat)
177+
image.affine = affine
178+
data = data.copy()
179+
data = torch.as_tensor(data)
180+
image.set_data(data.unsqueeze(0))

src/zerodose/paths.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,9 @@
55

66

77
# please refer to the readme on where to get the parameters. Save them in this folder:
8-
folder_with_parameter_files = os.path.join(os.path.expanduser("~"), ".zerodose_params")
8+
folder_with_parameter_files = os.path.join(
9+
os.path.expanduser("~"), ".zerodose_data", "model_params"
10+
)
11+
folder_with_templates = os.path.join(
12+
os.path.expanduser("~"), ".zerodose_data", "templates"
13+
)

src/zerodose/processing.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,44 @@ def _crop_192_to_mni(arr: torch.Tensor) -> torch.Tensor:
3838
return parr
3939

4040

41-
def postprocess(arr: torch.Tensor) -> torch.Tensor:
42-
"""Postprocess the 192x192x192 image to MNI."""
43-
return _crop_192_to_mni(arr)
44-
45-
46-
class Pad(SpatialTransform):
41+
class PadAndCropToMNI(SpatialTransform):
4742
"""Pad the MNI image to 192x192x192."""
4843

49-
def __init__(self, **kwargs) -> None:
44+
def __init__(self, is_inverse=False, **kwargs) -> None:
5045
"""Initialize the transform."""
46+
self.is_inverse = is_inverse
5147
super().__init__(**kwargs)
5248

5349
def apply_transform(self, subject: Subject) -> Subject:
5450
"""Apply the transform to the subject."""
5551
for image in self.get_images(subject):
56-
_pad(image)
52+
if self.is_inverse:
53+
_pad_inv(image)
54+
else:
55+
_pad(image)
5756

5857
return subject
5958

6059
@staticmethod
6160
def is_invertible() -> bool:
6261
"""Return whether the transform is invertible."""
63-
return False
62+
return True
63+
64+
def inverse(self):
65+
"""Returns the inverse transform."""
66+
return PadAndCropToMNI(is_inverse=True)
6467

6568

6669
def _pad(image: tio.Image) -> None:
6770
data = processing._crop_mni_to_192(image.data)
6871
image.set_data(data)
6972

7073

74+
def _pad_inv(image: tio.Image) -> None:
75+
data = processing._crop_192_to_mni(image.data)
76+
image.set_data(data)
77+
78+
7179
class ToFloat32(SpatialTransform):
7280
"""Convert the image to float32."""
7381

0 commit comments

Comments
 (0)