Skip to content

Commit 5c3716b

Browse files
committed
ENH: Refactor PET workflow segmentation process
1 parent 4311dcd commit 5c3716b

File tree

4 files changed

+102
-37
lines changed

4 files changed

+102
-37
lines changed

petprep/workflows/base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def init_single_subject_wf(subject_id: str):
164164
)
165165

166166
from petprep.workflows.pet.base import init_pet_wf
167+
from petprep.workflows.pet.segmentation import init_segmentation_wf
167168

168169
workflow = Workflow(name=f'sub_{subject_id}_wf')
169170
workflow.__desc__ = f"""
@@ -526,6 +527,24 @@ def init_single_subject_wf(subject_id: str):
526527
if config.workflow.anat_only:
527528
return clean_datasinks(workflow)
528529

530+
segmentation_wf = init_segmentation_wf(
531+
seg=config.workflow.seg,
532+
name=f'pet_{config.workflow.seg}_seg_wf',
533+
)
534+
workflow.connect(
535+
[
536+
(
537+
anat_fit_wf,
538+
segmentation_wf,
539+
[
540+
('outputnode.t1w_preproc', 'inputnode.t1w_preproc'),
541+
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
542+
('outputnode.subject_id', 'inputnode.subject_id'),
543+
],
544+
),
545+
]
546+
)
547+
529548
# Append the PET section to the existing anatomical excerpt
530549
# That way we do not need to filter down the number of PET datasets
531550
pet_pre_desc = f"""
@@ -600,6 +619,10 @@ def init_single_subject_wf(subject_id: str):
600619
'inputnode.sphere_reg_fsLR',
601620
),
602621
]),
622+
(segmentation_wf, pet_wf, [
623+
('outputnode.segmentation', 'inputnode.segmentation'),
624+
('outputnode.dseg_tsv', 'inputnode.dseg_tsv'),
625+
]),
603626
]) # fmt:skip
604627

605628
if config.workflow.level == 'full':

petprep/workflows/pet/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def init_pet_wf(
115115
Registration spheres from fsnative to fsLR space, collated left, then right
116116
anat_ribbon
117117
Binary cortical ribbon mask in T1w space
118+
segmentation
119+
Segmentation file in T1w space
120+
dseg_tsv
121+
TSV with segmentation statistics
118122
anat2std_xfm
119123
Transform from anatomical space to standard space
120124
std_t1w
@@ -206,6 +210,8 @@ def init_pet_wf(
206210
'midthickness_fsLR',
207211
'cortex_mask',
208212
'anat_ribbon',
213+
'segmentation',
214+
'dseg_tsv',
209215
# Volumetric templates
210216
'anat2std_xfm',
211217
'std_t1w',
@@ -242,6 +248,8 @@ def init_pet_wf(
242248
('subjects_dir', 'inputnode.subjects_dir'),
243249
('subject_id', 'inputnode.subject_id'),
244250
('fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
251+
('segmentation', 'inputnode.segmentation'),
252+
('dseg_tsv', 'inputnode.dseg_tsv'),
245253
]),
246254
]) # fmt:skip
247255

petprep/workflows/pet/fit.py

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from .ref_tacs import init_pet_ref_tacs_wf
5252
from .reference_mask import init_pet_refmask_wf
5353
from .registration import init_pet_reg_wf
54-
from .segmentation import init_segmentation_wf
5554

5655

5756
def init_pet_fit_wf(
@@ -103,6 +102,10 @@ def init_pet_fit_wf(
103102
FreeSurfer subject ID
104103
fsnative2t1w_xfm
105104
LTA-style affine matrix translating from FreeSurfer-conformed subject space to T1w
105+
segmentation
106+
Segmentation file in T1w space
107+
dseg_tsv
108+
TSV with segmentation statistics
106109
107110
Outputs
108111
-------
@@ -116,6 +119,10 @@ def init_pet_fit_wf(
116119
petref2anat_xfm
117120
Affine transform mapping from PET reference space to the anatomical
118121
space.
122+
segmentation
123+
Segmentation file in T1w space
124+
dseg_tsv
125+
TSV with segmentation statistics
119126
120127
See Also
121128
--------
@@ -169,6 +176,8 @@ def init_pet_fit_wf(
169176
'subjects_dir',
170177
'subject_id',
171178
'fsnative2t1w_xfm',
179+
'segmentation',
180+
'dseg_tsv',
172181
],
173182
),
174183
name='inputnode',
@@ -192,6 +201,11 @@ def init_pet_fit_wf(
192201

193202
# If all derivatives exist, inputnode could go unconnected, so add explicitly
194203
workflow.add_nodes([inputnode])
204+
workflow.connect(
205+
[
206+
(inputnode, outputnode, [('segmentation', 'segmentation'), ('dseg_tsv', 'dseg_tsv')]),
207+
]
208+
)
195209

196210
petref_buffer = pe.Node(
197211
niu.IdentityInterface(fields=['petref', 'pet_file']),
@@ -397,41 +411,10 @@ def init_pet_fit_wf(
397411
ds_petmask_wf.inputs.inputnode.source_files = [pet_file]
398412
workflow.connect([(merge_mask, ds_petmask_wf, [('out', 'inputnode.petmask')])])
399413

400-
# Stage 4: Segmentation
401-
config.loggers.workflow.info(
402-
'PET Stage 4: Adding segmentation workflow using the segmentation: %s', config.workflow.seg
403-
)
404-
segmentation_wf = init_segmentation_wf(
405-
seg=config.workflow.seg,
406-
name=f'pet_{config.workflow.seg}_seg_wf',
407-
)
408-
409-
workflow.connect(
410-
[
411-
(
412-
inputnode,
413-
segmentation_wf,
414-
[
415-
('t1w_preproc', 'inputnode.t1w_preproc'),
416-
('subject_id', 'inputnode.subject_id'),
417-
('subjects_dir', 'inputnode.subjects_dir'),
418-
],
419-
),
420-
(
421-
segmentation_wf,
422-
outputnode,
423-
[
424-
('outputnode.segmentation', 'segmentation'),
425-
('outputnode.dseg_tsv', 'dseg_tsv'),
426-
],
427-
),
428-
]
429-
)
430-
431-
# Stage 5: Reference mask generation
414+
# Stage 4: Reference mask generation
432415
if config.workflow.ref_mask_name:
433416
config.loggers.workflow.info(
434-
'PET Stage 5: Generating %s reference mask',
417+
'PET Stage 4: Generating %s reference mask',
435418
config.workflow.ref_mask_name,
436419
)
437420

@@ -487,10 +470,10 @@ def init_pet_fit_wf(
487470
workflow.connect(
488471
[
489472
(
490-
segmentation_wf,
473+
inputnode,
491474
refmask_wf,
492475
[
493-
('outputnode.segmentation', 'inputnode.seg_file'),
476+
('segmentation', 'inputnode.seg_file'),
494477
],
495478
),
496479
(

petprep/workflows/tests/test_base.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pathlib import Path
55
from unittest.mock import patch
66

7+
from pathlib import Path
8+
79
import nibabel as nb
810
import numpy as np
911
import pytest
@@ -13,7 +15,7 @@
1315
from niworkflows.utils.testing import generate_bids_skeleton
1416

1517
from ... import config
16-
from ..base import init_petprep_wf
18+
from ..base import init_petprep_wf, init_single_subject_wf
1719
from ..tests import mock_config
1820

1921
BASE_LAYOUT = {
@@ -96,6 +98,55 @@ def bids_root(tmp_path_factory):
9698
return bids_dir
9799

98100

101+
@pytest.fixture(scope='module')
102+
def multisession_bids_root(tmp_path_factory):
103+
base = tmp_path_factory.mktemp('multisession')
104+
bids_dir = base / 'bids'
105+
bids_dir.mkdir(parents=True, exist_ok=True)
106+
img = nb.Nifti1Image(np.zeros((10, 10, 10, 10)), np.eye(4))
107+
(bids_dir / 'dataset_description.json').write_text(
108+
'{"Name": "Test", "BIDSVersion": "1.8.0"}'
109+
)
110+
for ses in ['01', '02']:
111+
anat_dir = bids_dir / 'sub-01' / f'ses-{ses}' / 'anat'
112+
pet_dir = bids_dir / 'sub-01' / f'ses-{ses}' / 'pet'
113+
anat_dir.mkdir(parents=True, exist_ok=True)
114+
pet_dir.mkdir(parents=True, exist_ok=True)
115+
img.to_filename(anat_dir / f'sub-01_ses-{ses}_T1w.nii.gz')
116+
pet_path = pet_dir / f'sub-01_ses-{ses}_task-rest_run-1_pet.nii.gz'
117+
img.to_filename(pet_path)
118+
(pet_path.with_suffix('').with_suffix('.json')).write_text(
119+
'{"FrameTimesStart": [0], "FrameDuration": [1]}'
120+
)
121+
return bids_dir
122+
123+
124+
def test_segmentation_shared_across_runs(multisession_bids_root):
125+
with mock_config(bids_dir=multisession_bids_root):
126+
wf = init_single_subject_wf('01')
127+
flatgraph = wf._create_flat_graph()
128+
generate_expanded_graph(flatgraph)
129+
130+
seg_wf_name = f'pet_{config.workflow.seg}_seg_wf'
131+
seg_nodes = [n for n in wf.list_node_names() if n.startswith(seg_wf_name)]
132+
assert seg_nodes
133+
134+
pet_wf_names = [
135+
n
136+
for n in {name.split('.')[0] for name in wf.list_node_names() if name.startswith('pet_')}
137+
if n != seg_wf_name
138+
]
139+
assert len(pet_wf_names) == 2
140+
141+
seg_node = wf.get_node(seg_wf_name)
142+
for name in pet_wf_names:
143+
pet_node = wf.get_node(name)
144+
edge = wf._graph.get_edge_data(seg_node, pet_node)
145+
assert ('outputnode.segmentation', 'inputnode.segmentation') in edge['connect']
146+
assert ('outputnode.dseg_tsv', 'inputnode.dseg_tsv') in edge['connect']
147+
assert all('_seg_wf' not in n for n in pet_node.list_node_names())
148+
149+
99150
def _make_params(
100151
pet2anat_init: str = 'auto',
101152
medial_surface_nan: bool = False,

0 commit comments

Comments
 (0)