Skip to content

Commit af4d5f4

Browse files
committed
ENH: improve pet mask generation
1 parent 9039aaa commit af4d5f4

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

fmriprep/workflows/pet/confounds.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,12 +580,12 @@ def _binary_union(mask1, mask2):
580580

581581

582582
def _smooth_binarize(in_file, fwhm=10.0, thresh=0.2):
583-
"""Smooth ``in_file`` with a Gaussian kernel and binarize at ``thresh``."""
583+
"""Smooth ``in_file`` with a Gaussian kernel, binarize and keep largest cluster."""
584584
from pathlib import Path
585585

586586
import nibabel as nb
587587
import numpy as np
588-
from scipy.ndimage import gaussian_filter
588+
from scipy.ndimage import gaussian_filter, label
589589

590590
img = nb.load(in_file)
591591
data = img.get_fdata(dtype=np.float32)
@@ -594,6 +594,12 @@ def _smooth_binarize(in_file, fwhm=10.0, thresh=0.2):
594594
smoothed = gaussian_filter(data, sigma=sigma)
595595
mask = smoothed > (thresh * smoothed.max())
596596

597+
labeled, n_labels = label(mask)
598+
if n_labels > 1:
599+
sizes = np.bincount(labeled.ravel())
600+
sizes[0] = 0 # ignore background
601+
mask = labeled == sizes.argmax()
602+
597603
out_img = img.__class__(mask.astype('uint8'), img.affine, img.header)
598604
out_img.set_data_dtype('uint8')
599605
out_name = Path('smoothed_bin_mask.nii.gz').absolute()
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import nibabel as nb
2+
import numpy as np
3+
from scipy.ndimage import label
4+
5+
from ..confounds import _smooth_binarize
6+
7+
8+
def test_smooth_binarize_largest(tmp_path):
9+
data = np.zeros((5, 5, 5))
10+
data[1:3, 1:3, 1:3] = 1
11+
data[4, 4, 4] = 1
12+
img = nb.Nifti1Image(data, np.eye(4))
13+
src = tmp_path / 'input.nii.gz'
14+
img.to_filename(src)
15+
16+
out = _smooth_binarize(str(src), fwhm=0.0, thresh=0.5)
17+
result = nb.load(out).get_fdata()
18+
_, num = label(result > 0)
19+
assert num == 1

0 commit comments

Comments
 (0)