Skip to content

Commit 21899c6

Browse files
author
Christian
committed
FIX: binarize mask
1 parent 000fd42 commit 21899c6

File tree

4 files changed

+28
-5
lines changed

4 files changed

+28
-5
lines changed

src/zerodose/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from zerodose import utils
99
from zerodose.niftyreg_wrapper import NiftyRegistration
10+
from zerodose.processing import Binarize
1011
from zerodose.processing import PadAndCropToMNI
1112
from zerodose.processing import ToFloat32
1213

@@ -55,6 +56,7 @@ def _get_augmentation_transform_val(self, do_registration=True) -> tio.Compose:
5556

5657
augmentations.extend(
5758
[
59+
Binarize(include=["mask"]),
5860
tio.transforms.ZNormalization(include=["mr"], masking_method="mask"),
5961
PadAndCropToMNI(include=["mr", "mask"]),
6062
ToFloat32(include=["mr"]),

src/zerodose/niftyreg_wrapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def _read_matrix_nifty(file_name):
3636

3737
def _register_mri_to_mni(mri_fname, ref):
3838
temp_aff = tempfile.NamedTemporaryFile(delete=False)
39-
40-
out_mri_fname = os.path.dirname(mri_fname) + "/out_mri1.nii.gz"
39+
temp_out = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
40+
out_mri_fname = temp_out.name
4141
mni_template = ref
4242
niftyreg.main(
4343
[
@@ -55,7 +55,9 @@ def _register_mri_to_mni(mri_fname, ref):
5555
)
5656
affine_mat = _read_matrix_nifty(temp_aff.name)
5757
temp_aff.close()
58+
temp_out.close()
5859
os.remove(temp_aff.name)
60+
os.remove(temp_out.name)
5961
return affine_mat
6062

6163

src/zerodose/processing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@ def inverse(self):
6666
return PadAndCropToMNI(is_inverse=True)
6767

6868

69+
class Binarize(SpatialTransform):
70+
"""Binarize an image based on a threshhold."""
71+
72+
def __init__(self, threshold=0.5, **kwargs) -> None:
73+
"""Initialize the transform."""
74+
self.threshold = threshold
75+
super().__init__(**kwargs)
76+
77+
def apply_transform(self, subject: Subject) -> Subject:
78+
"""Apply the transform to the subject."""
79+
for image in self.get_images(subject):
80+
data = image.data
81+
image.set_data(data > self.threshold)
82+
return subject
83+
84+
6985
def _pad(image: tio.Image) -> None:
7086
data = processing._crop_mni_to_192(image.data)
7187
image.set_data(data)

tests/test_slow.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _get_mni_dir():
2424
return mni_dir
2525

2626

27-
def _augment_mni_img_for_tests(image_in_path, image_out_path):
27+
def _augment_mni_img_for_tests(image_in_path, image_out_path, binarize=False):
2828
rad = np.deg2rad(10)
2929
cos_gamma = np.cos(rad)
3030
sin_gamma = np.sin(rad)
@@ -51,6 +51,9 @@ def _augment_mni_img_for_tests(image_in_path, image_out_path):
5151
after_rot = resample_from_to(
5252
img_in, ((190, 200, 170), affine_mat.dot(img_in.affine))
5353
)
54+
# if binarize:
55+
# data = after_rot.get_fdata()
56+
# data.get
5457

5558
ornt = np.array([[0, 1], [1, 1], [2, 1]])
5659

@@ -164,7 +167,7 @@ def sbpet_outputfile():
164167
return fn
165168

166169

167-
@pytest.mark.slow
170+
# @pytest.mark.slow
168171
def test_syn_mni(runner, mri_mni_file, mask_mni_file, sbpet_outputfile) -> None:
169172
"""Test the syn command with the standard model and MNI files."""
170173
result = runner.invoke(
@@ -199,4 +202,4 @@ def test_syn_niftyreg(mri_aug_file, mask_aug_file, sbpet_outputfile) -> None:
199202
device="cuda:0",
200203
)
201204

202-
assert nib.load(mri_aug_file).get_fdata().shape == start_shape
205+
assert nib.load(sbpet_outputfile).get_fdata().shape == start_shape

0 commit comments

Comments
 (0)