diff --git a/pyproject.toml b/pyproject.toml index 37ab4b2241..aaef895ccf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,7 +193,6 @@ ignore = [ "B019", "SIM108", "C901", - "UP038", ] [tool.ruff.lint.flake8-quotes] diff --git a/sdcflows/workflows/base.py b/sdcflows/workflows/base.py index 88c5afd267..40f938b6a6 100644 --- a/sdcflows/workflows/base.py +++ b/sdcflows/workflows/base.py @@ -137,6 +137,7 @@ def init_fmap_preproc_wf( fmap_derivatives_wf = init_fmap_derivatives_wf( output_dir=str(output_dir), write_coeff=True, + write_mask=True, bids_fmap_id=estimator.bids_id, name=f'fmap_derivatives_wf_{estimator.sanitized_id}', ) @@ -157,18 +158,16 @@ def init_fmap_preproc_wf( niu.IdentityInterface(fields=fields), name=f'in_{estimator.sanitized_id}', ) - # fmt:off workflow.connect([ (inputnode, est_wf, [(f, f"inputnode.{f}") for f in fields]) - ]) - # fmt:on + ]) # fmt:skip - # fmt:off workflow.connect([ (est_wf, fmap_derivatives_wf, [ ("outputnode.fmap", "inputnode.fieldmap"), ("outputnode.fmap_ref", "inputnode.fmap_ref"), ("outputnode.fmap_coeff", "inputnode.fmap_coeff"), + ("outputnode.fmap_mask", "inputnode.fmap_mask"), ]), (est_wf, fmap_reports_wf, [ ("outputnode.fmap", "inputnode.fieldmap"), @@ -176,14 +175,15 @@ def init_fmap_preproc_wf( ("outputnode.fmap_mask", "inputnode.fmap_mask"), ]), (est_wf, out_map, [ - ("outputnode.fmap", "fmap"), + ("outputnode.method", "method") + ]), + (fmap_derivatives_wf, out_map, [ + ("outputnode.fieldmap", "fmap"), ("outputnode.fmap_ref", "fmap_ref"), ("outputnode.fmap_coeff", "fmap_coeff"), ("outputnode.fmap_mask", "fmap_mask"), - ("outputnode.method", "method") ]), - ]) - # fmt:on + ]) # fmt:skip for field, mergenode in out_merge.items(): workflow.connect(out_map, field, mergenode, f'in{n}') diff --git a/sdcflows/workflows/outputs.py b/sdcflows/workflows/outputs.py index ce5aa09997..a7913d4a0c 100644 --- a/sdcflows/workflows/outputs.py +++ b/sdcflows/workflows/outputs.py @@ -98,14 +98,12 @@ def init_fmap_reports_wf( suffix='fieldmap', desc=fmap_type, dismiss_entities=('fmap',), - allowed_entities=tuple(custom_entities.keys()), + allowed_entities=tuple(custom_entities), ), name='ds_fmap_report', ) - for k, v in custom_entities.items(): - setattr(ds_fmap_report.inputs, k, v) + ds_fmap_report.inputs.trait_set(**custom_entities) - # fmt:off workflow.connect([ (inputnode, fmap_rpt, [(("fieldmap", _pop), "fieldmap"), ("fmap_ref", "reference"), @@ -113,8 +111,7 @@ def init_fmap_reports_wf( (fmap_rpt, ds_fmap_report, [("out_report", "in_file")]), (inputnode, ds_fmap_report, [("source_files", "source_file")]), - ]) - # fmt:on + ]) # fmt:skip return workflow @@ -126,6 +123,7 @@ def init_fmap_derivatives_wf( custom_entities=None, name='fmap_derivatives_wf', write_coeff=False, + write_mask=False, ): """ Set up datasinks to store derivatives in the right location. @@ -162,10 +160,14 @@ def init_fmap_derivatives_wf( workflow = pe.Workflow(name=name) inputnode = pe.Node( niu.IdentityInterface( - fields=['source_files', 'fieldmap', 'fmap_coeff', 'fmap_ref', 'fmap_meta'] + fields=['source_files', 'fieldmap', 'fmap_coeff', 'fmap_ref', 'fmap_mask', 'fmap_meta'] ), name='inputnode', ) + outputnode = pe.Node( + niu.IdentityInterface(fields=['fieldmap', 'fmap_coeff', 'fmap_ref', 'fmap_mask']), + name='outputnode', + ) merge_fmap = pe.Node(MergeSeries(), name='merge_fmap') @@ -176,7 +178,7 @@ def init_fmap_derivatives_wf( suffix='fieldmap', datatype='fmap', dismiss_entities=('fmap',), - allowed_entities=tuple(custom_entities.keys()), + allowed_entities=tuple(custom_entities), ), name='ds_reference', ) @@ -188,7 +190,7 @@ def init_fmap_derivatives_wf( suffix='fieldmap', datatype='fmap', compress=True, - allowed_entities=tuple(custom_entities.keys()), + allowed_entities=tuple(custom_entities), ), name='ds_fieldmap', ) @@ -196,11 +198,9 @@ def init_fmap_derivatives_wf( if bids_fmap_id: ds_fieldmap.inputs.B0FieldIdentifier = bids_fmap_id - for k, v in custom_entities.items(): - setattr(ds_reference.inputs, k, v) - setattr(ds_fieldmap.inputs, k, v) + ds_reference.inputs.trait_set(**custom_entities) + ds_fieldmap.inputs.trait_set(**custom_entities) - # fmt:off workflow.connect([ (inputnode, merge_fmap, [("fieldmap", "in_files")]), (inputnode, ds_reference, [("source_files", "source_file"), @@ -213,8 +213,38 @@ def init_fmap_derivatives_wf( (("out_file", _getname), "AnatomicalReference"), ]), (inputnode, ds_fieldmap, [(("fmap_meta", _selectintent), "IntendedFor")]), - ]) - # fmt:on + (ds_fieldmap, outputnode, [("out_file", "fieldmap")]), + (ds_reference, outputnode, [("out_file", "fmap_ref")]), + ]) # fmt:skip + + if write_mask: + ds_mask = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + compress=True, + desc='brain', + suffix='mask', + datatype='fmap', + dismiss_entities=('fmap',), + allowed_entities=tuple(custom_entities), + ), + name='ds_mask', + ) + ds_mask._interface._file_patterns += ( + 'sub-{subject}[/ses-{session}]/{datatype|fmap}/' + 'sub-{subject}[_ses-{session}][_hash-{hash}][_acq-{acquisition}]' + '[_dir-{direction}][_run-{run}][_part-{part}][_space-{space}]' + '[_cohort-{cohort}][_res-{resolution}][_fmapid-{fmapid}]' + '[_desc-{desc}]_{suffix}{extension<.nii|.nii.gz|.json>|.nii.gz}', + ) + + ds_mask.inputs.trait_set(**custom_entities) + + workflow.connect([ + (inputnode, ds_mask, [("source_files", "source_file"), + ("fmap_mask", "in_file")]), + (ds_mask, outputnode, [("out_file", "fmap_mask")]), + ]) # fmt:skip if not write_coeff: return workflow @@ -225,7 +255,7 @@ def init_fmap_derivatives_wf( suffix='fieldmap', datatype='fmap', compress=True, - allowed_entities=tuple(custom_entities.keys()), + allowed_entities=tuple(custom_entities), ), name='ds_coeff', iterfield=('in_file', 'desc'), @@ -233,18 +263,16 @@ def init_fmap_derivatives_wf( gen_desc = pe.Node(niu.Function(function=_gendesc), name='gen_desc') - for k, v in custom_entities.items(): - setattr(ds_coeff.inputs, k, v) + ds_coeff.inputs.trait_set(**custom_entities) - # fmt:off workflow.connect([ (inputnode, ds_coeff, [("source_files", "source_file"), ("fmap_coeff", "in_file")]), (inputnode, gen_desc, [("fmap_coeff", "infiles")]), (gen_desc, ds_coeff, [("out", "desc")]), (ds_coeff, ds_fieldmap, [(("out_file", _getname), "AssociatedCoefficients")]), - ]) - # fmt:on + (ds_coeff, outputnode, [("out_file", "fmap_coeff")]), + ]) # fmt:skip return workflow