Skip to content

Commit 6dbf4bc

Browse files
committed
FIX: Clean up workflow logic
1 parent 3d4da90 commit 6dbf4bc

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

nibabies/workflows/anatomical/preproc.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,18 @@ def init_csf_norm_wf(name: str = 'csf_norm_wf') -> LiterateWorkflow:
7676
'The CSF mask was used to normalize the anatomical template by the median of voxels '
7777
'within the mask.'
7878
)
79-
inputnode = niu.IdentityInterface(fields=['anat_preproc', 'anat_tpms'], name='inputnode')
80-
outputnode = niu.IdentityInterface(fields=['anat_preproc'], name='outputnode')
79+
inputnode = pe.Node(
80+
niu.IdentityInterface(fields=['anat_preproc', 'anat_tpms']),
81+
name='inputnode',
82+
)
83+
outputnode = pe.Node(niu.IdentityInterface(fields=['anat_preproc']), name='outputnode')
8184

8285
# select CSF from BIDS-ordered list (GM, WM, CSF)
8386
select_csf = pe.Node(niu.Select(index=2), name='select_csf')
8487
norm_csf = pe.Node(niu.Function(function=_normalize_roi), name='norm_csf')
8588

8689
workflow.connect([
87-
(inputnode, select_csf, [('anat_tpms', 'in_list')]),
90+
(inputnode, select_csf, [('anat_tpms', 'inlist')]),
8891
(select_csf, norm_csf, [('out', 'mask_file')]),
8992
(inputnode, norm_csf, [('anat_preproc', 'in_file')]),
9093
(norm_csf, outputnode, [('out', 'anat_preproc')]),
@@ -99,13 +102,14 @@ def _normalize_roi(in_file, mask_file, threshold=0.2, out_file=None):
99102
import numpy as np
100103

101104
img = nb.load(in_file)
102-
img_data = img.get_fdata()
105+
img_data = np.asanyarray(img.dataobj)
103106
mask_img = nb.load(mask_file)
104107
# binary mask
105-
bin_mask = mask_img.get_fdata() > threshold
108+
bin_mask = np.asanyarray(mask_img.dataobj) > threshold
106109
mask_data = bin_mask * img_data
110+
masked_data = mask_data[mask_data > 0]
107111

108-
median = np.median(mask_data[mask_data > 0])
112+
median = np.median(masked_data).astype(masked_data.dtype)
109113
normed_data = np.maximum(img_data, bin_mask * median)
110114

111115
oimg = img.__class__(normed_data, img.affine, img.header)

0 commit comments

Comments
 (0)