Skip to content

Commit 6512180

Browse files
committed
feat: Add flag to set TRT fallback behavior
1 parent 7e65722 commit 6512180

File tree

6 files changed

+43
-2
lines changed

6 files changed

+43
-2
lines changed

fmriprep/cli/parser.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ def _slice_time_ref(value, parser):
149149
raise parser.error(f'Slice time reference must be in range 0-1. Received {value}.')
150150
return value
151151

152+
def _fallback_trt(value, parser):
153+
if value == 'estimated':
154+
return value
155+
try:
156+
return float(value)
157+
except ValueError:
158+
raise parser.error(
159+
f'Falling back to TRT must be a number or "estimated". Received {value}.'
160+
) from None
161+
152162
verstr = f'fMRIPrep v{config.environment.version}'
153163
currentv = Version(config.environment.version)
154164
is_release = not any((currentv.is_devrelease, currentv.is_prerelease, currentv.is_postrelease))
@@ -163,6 +173,7 @@ def _slice_time_ref(value, parser):
163173
PositiveInt = partial(_min_one, parser=parser)
164174
BIDSFilter = partial(_bids_filter, parser=parser)
165175
SliceTimeRef = partial(_slice_time_ref, parser=parser)
176+
FallbackTRT = partial(_fallback_trt, parser=parser)
166177

167178
# Arguments as specified by BIDS-Apps
168179
# required, positional arguments
@@ -417,6 +428,15 @@ def _slice_time_ref(value, parser):
417428
type=int,
418429
help='Number of nonsteady-state volumes. Overrides automatic detection.',
419430
)
431+
g_conf.add_argument(
432+
'--fallback-total-readout-time',
433+
required=False,
434+
action='store',
435+
default=None,
436+
type=FallbackTRT,
437+
help='Fallback value for Total Readout Time (TRT) calculation. '
438+
'May be a number or "estimated".',
439+
)
420440
g_conf.add_argument(
421441
'--random-seed',
422442
dest='_random_seed',

fmriprep/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,9 @@ class workflow(_Config):
575575
"""Remove the mean from fieldmaps."""
576576
force_syn = None
577577
"""Run *fieldmap-less* susceptibility-derived distortions estimation."""
578+
fallback_total_readout_time = None
579+
"""Infer the total readout time if unavailable from authoritative metadata.
580+
This may be a number or the string "estimated"."""
578581
hires = None
579582
"""Run FreeSurfer ``recon-all`` with the ``-hires`` flag."""
580583
fs_no_resume = None

fmriprep/interfaces/resampling.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ def _run_interface(self, runtime):
191191
class DistortionParametersInputSpec(TraitedSpec):
192192
in_file = File(exists=True, desc='EPI image corresponding to the metadata')
193193
metadata = traits.Dict(mandatory=True, desc='metadata corresponding to the inputs')
194+
fallback = traits.Either(
195+
None,
196+
'estimated',
197+
traits.Float,
198+
usedefault=True,
199+
desc='Fallback value for missing metadata',
200+
)
194201

195202

196203
class DistortionParametersOutputSpec(TraitedSpec):
@@ -215,6 +222,8 @@ def _run_interface(self, runtime):
215222
self._results['readout_time'] = get_trt(
216223
self.inputs.metadata,
217224
self.inputs.in_file or None,
225+
use_estimate=self.inputs.fallback == 'estimated',
226+
fallback=self.inputs.fallback if isinstance(self.inputs.fallback, float) else None,
218227
)
219228
self._results['pe_direction'] = self.inputs.metadata['PhaseEncodingDirection']
220229
except (KeyError, ValueError):

fmriprep/workflows/bold/apply.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def init_bold_volumetric_resample_wf(
1717
metadata: dict,
1818
mem_gb: dict[str, float],
1919
jacobian: bool,
20+
fallback_total_readout_time: str | float | None = None,
2021
fieldmap_id: str | None = None,
2122
omp_nthreads: int = 1,
2223
name: str = 'bold_volumetric_resample_wf',
@@ -161,7 +162,10 @@ def init_bold_volumetric_resample_wf(
161162
run_without_submitting=True,
162163
)
163164
distortion_params = pe.Node(
164-
DistortionParameters(metadata=metadata),
165+
DistortionParameters(
166+
metadata=metadata,
167+
fallback=fallback_total_readout_time,
168+
),
165169
name='distortion_params',
166170
run_without_submitting=True,
167171
)

fmriprep/workflows/bold/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ def init_bold_wf(
383383
# Resample to anatomical space
384384
bold_anat_wf = init_bold_volumetric_resample_wf(
385385
metadata=all_metadata[0],
386+
fallback_total_readout_time=config.workflow.fallback_total_readout_time,
386387
fieldmap_id=fieldmap_id if not multiecho else None,
387388
omp_nthreads=omp_nthreads,
388389
mem_gb=mem_gb,

fmriprep/workflows/bold/fit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,11 @@ def init_bold_native_wf(
859859
)
860860

861861
distortion_params = pe.Node(
862-
DistortionParameters(metadata=metadata, in_file=bold_file),
862+
DistortionParameters(
863+
metadata=metadata,
864+
in_file=bold_file,
865+
fallback=config.workflow.fallback_total_readout_time,
866+
),
863867
name='distortion_params',
864868
run_without_submitting=True,
865869
)

0 commit comments

Comments
 (0)