Skip to content

Commit 80392f9

Browse files
authored
Merge pull request #382 from mgxd/resample/cifti-2mm
FIX: Select reference resolution based on grayordinates
2 parents e7c83cf + d4264e4 commit 80392f9

File tree

3 files changed

+41
-29
lines changed

3 files changed

+41
-29
lines changed

nibabies/tests/test_config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,23 @@ def test_config_spaces():
104104
] == ['MNIInfant_cohort-1_res-native', 'MNI152NLin6Asym_res-2', 'MNIInfant_cohort-1_res-2']
105105
_reset_config()
106106

107+
config.execution.output_spaces = None
108+
config.workflow.cifti_output = '170k'
109+
spaces = _load_spaces(1)
110+
111+
assert [str(s) for s in spaces.get_standard(full_spec=True)] == [
112+
'MNIInfant:cohort-1:res-native', # Default output space
113+
'MNI152NLin6Asym:res-1',
114+
'MNIInfant:cohort-1:res-1',
115+
]
116+
117+
assert [
118+
format_reference((s.fullname, s.spec))
119+
for s in spaces.references
120+
if s.standard and s.dim == 3
121+
] == ['MNIInfant_cohort-1_res-native', 'MNI152NLin6Asym_res-1', 'MNIInfant_cohort-1_res-1']
122+
_reset_config()
123+
107124

108125
@pytest.mark.parametrize(
109126
('master_seed', 'ants_seed', 'numpy_seed'), [(1, 17612, 8272), (100, 19094, 60232)]

nibabies/workflows/base.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -515,10 +515,12 @@ def init_single_subject_wf(
515515
]) # fmt:skip
516516

517517
if cifti_output and 'MNIInfant' in [ref.space for ref in spaces.references]:
518+
mniinfant_res = 2 if config.workflow.cifti_output == '91k' else 1
519+
518520
select_MNIInfant_xfm = pe.Node(
519521
KeySelect(
520522
fields=['anat2std_xfm', 'std2anat_xfm'],
521-
key=get_MNIInfant_key(spaces),
523+
key=get_MNIInfant_key(spaces, mniinfant_res),
522524
),
523525
name='select_MNIInfant_xfm',
524526
run_without_submitting=True,
@@ -840,7 +842,7 @@ def init_workflow_spaces(execution_spaces: SpatialReferences, age_months: int):
840842
spaces.add(Reference('MNI152NLin6Asym', {'res': vol_res}))
841843
# Ensure a non-native version of MNIInfant is added as a target
842844
cohort = cohort_by_months('MNIInfant', age_months)
843-
spaces.add(Reference('MNIInfant', {'cohort': cohort, 'res': 2}))
845+
spaces.add(Reference('MNIInfant', {'cohort': cohort, 'res': vol_res}))
844846

845847
return spaces
846848

@@ -950,15 +952,10 @@ def get_estimator(layout, fname):
950952
return field_source
951953

952954

953-
def get_MNIInfant_key(spaces: SpatialReferences) -> str:
955+
def get_MNIInfant_key(spaces: SpatialReferences, res: str | int) -> str:
954956
"""Parse spaces and return matching MNIInfant space, including cohort."""
955-
key = None
956-
for space in spaces.references:
957-
# str formats as <reference.name>:<reference.spec>
958-
if 'MNIInfant' in str(space) and 'res-2' in str(space):
959-
key = str(space)
960-
break
961-
962-
if key is None:
963-
raise KeyError(f'MNIInfant (resolution 2x2x2) not found in SpatialReferences: {spaces}')
964-
return key
957+
for ref in spaces.references:
958+
if ref.space == 'MNIInfant' and f'res-{res}' in str(ref):
959+
return ref.fullname
960+
961+
raise KeyError(f'MNIInfant (resolution {res}) not found in SpatialReferences: {spaces}')

nibabies/workflows/bold/base.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,8 @@ def init_bold_wf(
558558
]),
559559
]) # fmt:skip
560560

561-
if config.workflow.cifti_output:
561+
cifti_output = config.workflow.cifti_output
562+
if cifti_output:
562563
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
563564

564565
from nibabies.workflows.bold.alignment import (
@@ -581,7 +582,7 @@ def init_bold_wf(
581582
)
582583

583584
bold_fsLR_resampling_wf = init_bold_fsLR_resampling_wf(
584-
grayord_density=config.workflow.cifti_output,
585+
grayord_density=cifti_output,
585586
omp_nthreads=omp_nthreads,
586587
mem_gb=mem_gb['resampled'],
587588
)
@@ -615,7 +616,7 @@ def init_bold_wf(
615616
subcortical_mni_alignment_wf = init_subcortical_mni_alignment_wf()
616617

617618
bold_grayords_wf = init_bold_grayords_wf(
618-
grayord_density=config.workflow.cifti_output,
619+
grayord_density=cifti_output,
619620
repetition_time=all_metadata[0]['RepetitionTime'],
620621
)
621622

@@ -624,7 +625,7 @@ def init_bold_wf(
624625
base_directory=output_dir,
625626
dismiss_entities=DEFAULT_DISMISS_ENTITIES,
626627
space='fsLR',
627-
density=config.workflow.cifti_output,
628+
density=cifti_output,
628629
suffix='bold',
629630
compress=False,
630631
TaskName=all_metadata[0].get('TaskName'),
@@ -635,7 +636,8 @@ def init_bold_wf(
635636
)
636637
ds_bold_cifti.inputs.source_file = bold_file
637638

638-
inputnode.inputs.mniinfant_mask = get_MNIInfant_mask(spaces)
639+
mniinfant_res = 2 if cifti_output == '91k' else 1
640+
inputnode.inputs.mniinfant_mask = get_MNIInfant_mask(spaces, mniinfant_res)
639641

640642
workflow.connect([
641643
# Resample BOLD to MNI152NLin6Asym, may duplicate bold_std_wf above
@@ -747,11 +749,11 @@ def init_bold_wf(
747749
]) # fmt:skip
748750

749751
# MG: Carpetplot workflow only work with CIFTI
750-
if config.workflow.cifti_output:
752+
if cifti_output:
751753
carpetplot_wf = init_carpetplot_wf(
752754
mem_gb=mem_gb['resampled'],
753755
metadata=all_metadata[0],
754-
cifti_output=config.workflow.cifti_output,
756+
cifti_output=cifti_output,
755757
name='carpetplot_wf',
756758
)
757759

@@ -847,24 +849,20 @@ def _read_json(in_file):
847849
return loads(Path(in_file).read_text())
848850

849851

850-
def get_MNIInfant_mask(spaces: 'SpatialReferences') -> str:
852+
def get_MNIInfant_mask(spaces: 'SpatialReferences', res: str | int) -> str:
851853
"""Parse spaces and return matching MNIInfant space, including cohort."""
852854
import templateflow.api as tf
853855

854-
mask = None
855856
for ref in spaces.references:
856-
# str formats as <reference.name>:<reference.spec>
857-
if ref.space == 'MNIInfant' and ref.spec.get('res', '') != 'native':
858-
mask = str(
857+
if ref.space == 'MNIInfant' and f'res-{res}' in str(ref):
858+
return str(
859859
tf.get(
860860
'MNIInfant',
861861
cohort=ref.spec['cohort'],
862-
resolution=1,
862+
resolution=res,
863863
desc='brain',
864864
suffix='mask',
865865
)
866866
)
867867

868-
if mask is None:
869-
raise FileNotFoundError('MNIInfant brain mask not found.')
870-
return mask
868+
raise FileNotFoundError(f'MNIInfant mask (resolution {res}) not found.')

0 commit comments

Comments
 (0)