-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathsynthstrip.py
More file actions
231 lines (186 loc) · 7.79 KB
/
synthstrip.py
File metadata and controls
231 lines (186 loc) · 7.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# Modified from:
# https://github.com/nipreps/synthstrip/blob/main/nipreps/synthstrip/cli.py
# Original copyright (c) 2024, NiPreps developers
# Licensed under the Apache License, Version 2.0
# Changes made by the BrainLesion Preprocessing team (2025)
from pathlib import Path
from typing import Optional, Union, cast

import nibabel as nib
import numpy as np
import scipy
import scipy.ndimage
import torch
from nibabel.nifti1 import Nifti1Image
from nipreps.synthstrip.model import StripModel
from nitransforms.linear import Affine

from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor
from brainles_preprocessing.utils.zenodo import fetch_synthstrip
class SynthStripExtractor(BrainExtractor):
    def __init__(self, border: int = 1):
        """
        Brain extraction using SynthStrip with preprocessing conforming to model requirements.

        This is an optional dependency - to use this extractor, you need to install the
        `brainles_preprocessing` package with the `synthstrip` extra:
        `pip install brainles_preprocessing[synthstrip]`

        Adapted from https://github.com/nipreps/synthstrip

        Args:
            border (int): Mask border threshold in mm. Defaults to 1.
        """
        super().__init__()
        self.border = border

    def _setup_model(self, device: torch.device) -> StripModel:
        """
        Load SynthStrip model and prepare it for inference on the specified device.

        Args:
            device: Device to load the model onto.

        Returns:
            A configured and ready-to-use StripModel.
        """
        # necessary for speed gains (according to original nipreps authors)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

        with torch.no_grad():
            model = StripModel()
            model.to(device)
            model.eval()

        # Load the model weights fetched from Zenodo.
        # NOTE(review): torch.load performs full unpickling by default; consider
        # weights_only=True once checkpoint compatibility is confirmed.
        weights_folder = fetch_synthstrip()
        weights = weights_folder / "synthstrip.1.pt"
        checkpoint = torch.load(weights, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        return model

    def _conform(self, input_nii: Nifti1Image) -> Nifti1Image:
        """
        Resample the input image to match SynthStrip's expected input space:
        1 mm isotropic voxels, LIA orientation, with each dimension rounded up
        to a multiple of 64 and clipped to the range [192, 320].

        Args:
            input_nii (Nifti1Image): Input NIfTI image to conform.

        Raises:
            ValueError: If the input NIfTI image does not have a valid affine.

        Returns:
            A new NIfTI image with conformed shape and affine.
        """
        shape = np.array(input_nii.shape[:3])
        affine = input_nii.affine
        if affine is None:
            raise ValueError("Input NIfTI image must have a valid affine.")

        # Get corner voxel centers in index coords
        corner_centers_ijk = (
            np.array(
                [
                    (i, j, k)
                    for k in (0, shape[2] - 1)
                    for j in (0, shape[1] - 1)
                    for i in (0, shape[0] - 1)
                ]
            )
            + 0.5
        )

        # Get corner voxel centers in mm
        corners_xyz = (
            affine
            @ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T
        )

        # Target affine is 1mm voxels in LIA orientation
        target_affine = np.diag([-1.0, 1.0, -1.0, 1.0])[:, (0, 2, 1, 3)]

        # Target shape: cover the full physical extent of the input
        extent = corners_xyz.min(1)[:3], corners_xyz.max(1)[:3]
        target_shape = ((extent[1] - extent[0]) / 1.0 + 0.999).astype(int)

        # SynthStrip likes dimensions be multiple of 64 (192, 256, or 320)
        target_shape = np.clip(
            np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320
        )

        # Ensure shape ordering is LIA too: swap the 2nd and 3rd axes.
        # Explicit int() copies avoid any aliasing with numpy array views.
        target_shape[1], target_shape[2] = int(target_shape[2]), int(target_shape[1])

        # Coordinates of center voxel do not change
        input_c = affine @ np.hstack((0.5 * (shape - 1), 1.0))
        target_c = target_affine @ np.hstack((0.5 * (target_shape - 1), 1.0))

        # Rebase the origin of the new, plumb affine
        target_affine[:3, 3] -= target_c[:3] - input_c[:3]

        nii = Affine(
            reference=Nifti1Image(
                np.zeros(target_shape),
                target_affine,
                None,
            ),
        ).apply(input_nii)
        return cast(Nifti1Image, nii)

    def _resample_like(
        self,
        image: Nifti1Image,
        target: Nifti1Image,
        output_dtype: Optional[np.dtype] = None,
        cval: Union[int, float] = 0,
    ) -> Nifti1Image:
        """
        Resample the input image to match the target's grid using an identity transform.

        Args:
            image: The image to be resampled.
            target: The reference image.
            output_dtype: Output data type.
            cval: Value to use for constant padding.

        Returns:
            A resampled NIfTI image.
        """
        result = Affine(reference=target).apply(
            image,
            output_dtype=output_dtype,
            cval=cval,
        )
        return cast(Nifti1Image, result)

    def extract(
        self,
        input_image_path: Union[str, Path],
        masked_image_path: Union[str, Path],
        brain_mask_path: Union[str, Path],
        device: Union[torch.device, str] = "cuda",
        num_threads: int = 1,
        **kwargs,
    ) -> None:
        """
        Extract the brain from an input image using SynthStrip.

        Args:
            input_image_path (Union[str, Path]): Path to the input image.
            masked_image_path (Union[str, Path]): Path to the output masked image.
            brain_mask_path (Union[str, Path]): Path to the output brain mask.
            device (Union[torch.device, str], optional): Device to use for computation. Defaults to "cuda".
            num_threads (int, optional): Number of threads to use for computation in CPU mode. Defaults to 1.

        Raises:
            RuntimeError: If no brain component is found in the model output.

        Returns:
            None: The function saves the masked image and brain mask to the specified paths.
        """
        device = torch.device(device) if isinstance(device, str) else device
        model = self._setup_model(device=device)

        if device.type == "cpu" and num_threads > 0:
            torch.set_num_threads(num_threads)

        # normalize intensities to [0, 1] after conforming to the model grid
        image = nib.load(input_image_path)
        image = cast(Nifti1Image, image)
        conformed = self._conform(image)
        in_data = conformed.get_fdata(dtype="float32")
        in_data -= in_data.min()
        # Guard against constant images: a zero 99th percentile would turn the
        # division below into NaNs.
        p99 = np.percentile(in_data, 99)
        if p99 > 0:
            in_data = np.clip(in_data / p99, 0, 1)
        in_data = in_data[np.newaxis, np.newaxis]

        # predict the surface distance transform
        input_tensor = torch.from_numpy(in_data).to(device)
        with torch.no_grad():
            sdt = model(input_tensor).cpu().numpy().squeeze()

        # unconform the sdt and extract mask
        sdt_target = self._resample_like(
            Nifti1Image(sdt, conformed.affine, None),
            image,
            output_dtype=np.dtype("int16"),
            cval=100,
        )
        sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16")

        # find largest CC (just do this to be safe for now)
        components = scipy.ndimage.label(sdt_data.squeeze() < self.border)[0]
        bincount = np.bincount(components.flatten())[1:]
        if bincount.size == 0:
            # no voxel fell below the border threshold -> nothing to extract
            raise RuntimeError(
                "SynthStrip did not detect any brain voxels in the input image."
            )
        mask = components == (np.argmax(bincount) + 1)
        # Use the modern scipy.ndimage namespace: scipy.ndimage.morphology was
        # deprecated in SciPy 1.8 and has since been removed.
        mask = scipy.ndimage.binary_fill_holes(mask)

        # write the masked output
        img_data = image.get_fdata()
        bg = np.min([0, img_data.min()])
        img_data[mask == 0] = bg
        Nifti1Image(img_data, image.affine, image.header).to_filename(
            masked_image_path,
        )

        # write the brain mask; cast to uint8 because nibabel rejects numpy
        # bool arrays as image data (header already declares uint8)
        hdr = image.header.copy()
        hdr.set_data_dtype("uint8")
        Nifti1Image(mask.astype("uint8"), image.affine, hdr).to_filename(
            brain_mask_path
        )