Skip to content

Commit 544642f

Browse files
authored
Merge pull request #113 from nipreps/refactor_segmentation
ENH: Refactor PET workflow segmentation process
2 parents 9d12689 + 2129b21 commit 544642f

File tree

5 files changed

+99
-51
lines changed

5 files changed

+99
-51
lines changed

petprep/workflows/base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def init_single_subject_wf(subject_id: str):
164164
)
165165

166166
from petprep.workflows.pet.base import init_pet_wf
167+
from petprep.workflows.pet.segmentation import init_segmentation_wf
167168

168169
workflow = Workflow(name=f'sub_{subject_id}_wf')
169170
workflow.__desc__ = f"""
@@ -523,6 +524,24 @@ def init_single_subject_wf(subject_id: str):
523524
]),
524525
]) # fmt:skip
525526

527+
segmentation_wf = init_segmentation_wf(
528+
seg=config.workflow.seg,
529+
name=f'pet_{config.workflow.seg}_seg_wf',
530+
)
531+
workflow.connect(
532+
[
533+
(
534+
anat_fit_wf,
535+
segmentation_wf,
536+
[
537+
('outputnode.t1w_preproc', 'inputnode.t1w_preproc'),
538+
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
539+
('outputnode.subject_id', 'inputnode.subject_id'),
540+
],
541+
),
542+
]
543+
)
544+
526545
if config.workflow.anat_only:
527546
return clean_datasinks(workflow)
528547

@@ -600,6 +619,10 @@ def init_single_subject_wf(subject_id: str):
600619
'inputnode.sphere_reg_fsLR',
601620
),
602621
]),
622+
(segmentation_wf, pet_wf, [
623+
('outputnode.segmentation', 'inputnode.segmentation'),
624+
('outputnode.dseg_tsv', 'inputnode.dseg_tsv'),
625+
]),
603626
]) # fmt:skip
604627

605628
if config.workflow.level == 'full':

petprep/workflows/pet/base.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def init_pet_wf(
115115
Registration spheres from fsnative to fsLR space, collated left, then right
116116
anat_ribbon
117117
Binary cortical ribbon mask in T1w space
118+
segmentation
119+
Segmentation file in T1w space
120+
dseg_tsv
121+
TSV with segmentation statistics
118122
anat2std_xfm
119123
Transform from anatomical space to standard space
120124
std_t1w
@@ -206,6 +210,8 @@ def init_pet_wf(
206210
'midthickness_fsLR',
207211
'cortex_mask',
208212
'anat_ribbon',
213+
'segmentation',
214+
'dseg_tsv',
209215
# Volumetric templates
210216
'anat2std_xfm',
211217
'std_t1w',
@@ -242,6 +248,8 @@ def init_pet_wf(
242248
('subjects_dir', 'inputnode.subjects_dir'),
243249
('subject_id', 'inputnode.subject_id'),
244250
('fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
251+
('segmentation', 'inputnode.segmentation'),
252+
('dseg_tsv', 'inputnode.dseg_tsv'),
245253
]),
246254
]) # fmt:skip
247255

@@ -391,9 +399,7 @@ def init_pet_wf(
391399
('t1w_tpms', 'inputnode.t1w_tpms'),
392400
('subjects_dir', 'inputnode.subjects_dir'),
393401
('subject_id', 'inputnode.subject_id'),
394-
]),
395-
(pet_fit_wf, pet_pvc_wf, [
396-
('outputnode.segmentation', 'inputnode.segmentation'),
402+
('segmentation', 'inputnode.segmentation'),
397403
]),
398404
(petref_t1w, pet_pvc_wf, [('output_image', 'inputnode.petref')]),
399405
(pet_pvc_wf, psf_meta, [
@@ -711,9 +717,9 @@ def init_pet_wf(
711717

712718
workflow.connect([
713719
(pet_t1w_src, pet_tacs_wf, [(pet_t1w_field, 'inputnode.pet_anat')]),
714-
(pet_fit_wf, pet_tacs_wf, [
715-
('outputnode.segmentation', 'inputnode.segmentation'),
716-
('outputnode.dseg_tsv', 'inputnode.dseg_tsv'),
720+
(inputnode, pet_tacs_wf, [
721+
('segmentation', 'inputnode.segmentation'),
722+
('dseg_tsv', 'inputnode.dseg_tsv'),
717723
]),
718724
(pet_tacs_wf, ds_pet_tacs, [('outputnode.timeseries', 'in_file')]),
719725
]) # fmt:skip

petprep/workflows/pet/fit.py

Lines changed: 11 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from .ref_tacs import init_pet_ref_tacs_wf
5252
from .reference_mask import init_pet_refmask_wf
5353
from .registration import init_pet_reg_wf
54-
from .segmentation import init_segmentation_wf
5554

5655

5756
def init_pet_fit_wf(
@@ -103,6 +102,10 @@ def init_pet_fit_wf(
103102
FreeSurfer subject ID
104103
fsnative2t1w_xfm
105104
LTA-style affine matrix translating from FreeSurfer-conformed subject space to T1w
105+
segmentation
106+
Segmentation file in T1w space
107+
dseg_tsv
108+
TSV with segmentation statistics
106109
107110
Outputs
108111
-------
@@ -169,6 +172,8 @@ def init_pet_fit_wf(
169172
'subjects_dir',
170173
'subject_id',
171174
'fsnative2t1w_xfm',
175+
'segmentation',
176+
'dseg_tsv',
172177
],
173178
),
174179
name='inputnode',
@@ -182,8 +187,6 @@ def init_pet_fit_wf(
182187
'pet_mask',
183188
'motion_xfm',
184189
'petref2anat_xfm',
185-
'segmentation',
186-
'dseg_tsv',
187190
'refmask',
188191
],
189192
),
@@ -397,41 +400,10 @@ def init_pet_fit_wf(
397400
ds_petmask_wf.inputs.inputnode.source_files = [pet_file]
398401
workflow.connect([(merge_mask, ds_petmask_wf, [('out', 'inputnode.petmask')])])
399402

400-
# Stage 4: Segmentation
401-
config.loggers.workflow.info(
402-
'PET Stage 4: Adding segmentation workflow using the segmentation: %s', config.workflow.seg
403-
)
404-
segmentation_wf = init_segmentation_wf(
405-
seg=config.workflow.seg,
406-
name=f'pet_{config.workflow.seg}_seg_wf',
407-
)
408-
409-
workflow.connect(
410-
[
411-
(
412-
inputnode,
413-
segmentation_wf,
414-
[
415-
('t1w_preproc', 'inputnode.t1w_preproc'),
416-
('subject_id', 'inputnode.subject_id'),
417-
('subjects_dir', 'inputnode.subjects_dir'),
418-
],
419-
),
420-
(
421-
segmentation_wf,
422-
outputnode,
423-
[
424-
('outputnode.segmentation', 'segmentation'),
425-
('outputnode.dseg_tsv', 'dseg_tsv'),
426-
],
427-
),
428-
]
429-
)
430-
431-
# Stage 5: Reference mask generation
403+
# Stage 4: Reference mask generation
432404
if config.workflow.ref_mask_name:
433405
config.loggers.workflow.info(
434-
'PET Stage 5: Generating %s reference mask',
406+
'PET Stage 4: Generating %s reference mask',
435407
config.workflow.ref_mask_name,
436408
)
437409

@@ -487,10 +459,10 @@ def init_pet_fit_wf(
487459
workflow.connect(
488460
[
489461
(
490-
segmentation_wf,
462+
inputnode,
491463
refmask_wf,
492464
[
493-
('outputnode.segmentation', 'inputnode.seg_file'),
465+
('segmentation', 'inputnode.seg_file'),
494466
],
495467
),
496468
(
@@ -572,7 +544,7 @@ def init_pet_fit_wf(
572544
]
573545
)
574546
else:
575-
config.loggers.workflow.info('PET Stage 5: Reference mask generation skipped')
547+
config.loggers.workflow.info('PET Stage 4: Reference mask generation skipped')
576548

577549
return workflow
578550

petprep/workflows/pet/tests/test_base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,8 @@ def test_pvc_receives_segmentation(bids_root: Path):
205205

206206
wf = init_pet_wf(pet_series=pet_series, precomputed={})
207207

208-
edge = wf._graph.get_edge_data(wf.get_node('pet_fit_wf'), wf.get_node('pet_pvc_wf'))
209-
assert ('outputnode.segmentation', 'inputnode.segmentation') in edge['connect']
208+
edge = wf._graph.get_edge_data(wf.get_node('inputnode'), wf.get_node('pet_pvc_wf'))
209+
assert ('segmentation', 'inputnode.segmentation') in edge['connect']
210210

211211

212212
def test_pet_tacs_wf_connections(bids_root: Path):
@@ -226,9 +226,9 @@ def test_pet_tacs_wf_connections(bids_root: Path):
226226
edge_anat = wf._graph.get_edge_data(wf.get_node('pet_anat_wf'), wf.get_node('pet_tacs_wf'))
227227
assert ('outputnode.pet_file', 'inputnode.pet_anat') in edge_anat['connect']
228228

229-
edge_fit = wf._graph.get_edge_data(wf.get_node('pet_fit_wf'), wf.get_node('pet_tacs_wf'))
230-
assert ('outputnode.segmentation', 'inputnode.segmentation') in edge_fit['connect']
231-
assert ('outputnode.dseg_tsv', 'inputnode.dseg_tsv') in edge_fit['connect']
229+
edge_input = wf._graph.get_edge_data(wf.get_node('inputnode'), wf.get_node('pet_tacs_wf'))
230+
assert ('segmentation', 'inputnode.segmentation') in edge_input['connect']
231+
assert ('dseg_tsv', 'inputnode.dseg_tsv') in edge_input['connect']
232232

233233
edge_ds = wf._graph.get_edge_data(wf.get_node('pet_tacs_wf'), wf.get_node('ds_pet_tacs'))
234234
assert ('outputnode.timeseries', 'in_file') in edge_ds['connect']

petprep/workflows/tests/test_base.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from niworkflows.utils.testing import generate_bids_skeleton
1414

1515
from ... import config
16-
from ..base import init_petprep_wf
16+
from ..base import init_petprep_wf, init_single_subject_wf
1717
from ..tests import mock_config
1818

1919
BASE_LAYOUT = {
@@ -96,6 +96,53 @@ def bids_root(tmp_path_factory):
9696
return bids_dir
9797

9898

99+
@pytest.fixture(scope='module')
100+
def multisession_bids_root(tmp_path_factory):
101+
base = tmp_path_factory.mktemp('multisession')
102+
bids_dir = base / 'bids'
103+
bids_dir.mkdir(parents=True, exist_ok=True)
104+
img = nb.Nifti1Image(np.zeros((10, 10, 10, 10)), np.eye(4))
105+
(bids_dir / 'dataset_description.json').write_text('{"Name": "Test", "BIDSVersion": "1.8.0"}')
106+
for ses in ['01', '02']:
107+
anat_dir = bids_dir / 'sub-01' / f'ses-{ses}' / 'anat'
108+
pet_dir = bids_dir / 'sub-01' / f'ses-{ses}' / 'pet'
109+
anat_dir.mkdir(parents=True, exist_ok=True)
110+
pet_dir.mkdir(parents=True, exist_ok=True)
111+
img.to_filename(anat_dir / f'sub-01_ses-{ses}_T1w.nii.gz')
112+
pet_path = pet_dir / f'sub-01_ses-{ses}_task-rest_run-1_pet.nii.gz'
113+
img.to_filename(pet_path)
114+
(pet_path.with_suffix('').with_suffix('.json')).write_text(
115+
'{"FrameTimesStart": [0], "FrameDuration": [1]}'
116+
)
117+
return bids_dir
118+
119+
120+
def test_segmentation_shared_across_runs(multisession_bids_root):
121+
with mock_config(bids_dir=multisession_bids_root):
122+
wf = init_single_subject_wf('01')
123+
flatgraph = wf._create_flat_graph()
124+
generate_expanded_graph(flatgraph)
125+
126+
seg_wf_name = f'pet_{config.workflow.seg}_seg_wf'
127+
seg_nodes = [n for n in wf.list_node_names() if n.startswith(seg_wf_name)]
128+
assert seg_nodes
129+
130+
pet_wf_names = [
131+
n
132+
for n in {name.split('.')[0] for name in wf.list_node_names() if name.startswith('pet_')}
133+
if n != seg_wf_name
134+
]
135+
assert len(pet_wf_names) == 2
136+
137+
seg_node = wf.get_node(seg_wf_name)
138+
for name in pet_wf_names:
139+
pet_node = wf.get_node(name)
140+
edge = wf._graph.get_edge_data(seg_node, pet_node)
141+
assert ('outputnode.segmentation', 'inputnode.segmentation') in edge['connect']
142+
assert ('outputnode.dseg_tsv', 'inputnode.dseg_tsv') in edge['connect']
143+
assert all('_seg_wf' not in n for n in pet_node.list_node_names())
144+
145+
99146
def _make_params(
100147
pet2anat_init: str = 'auto',
101148
medial_surface_nan: bool = False,

0 commit comments

Comments
 (0)