Skip to content

Commit e2c0fc0

Browse files
authored
FIX: Tag memory estimates in resamplers (#3150)
We'll start with the baseline assumption that our resampler uses about 4*the original BOLD series size, which was generally a good estimate for antsApplyTransform. Also tagging a few things with `run_without_submitting` and some tasks that are surely using more than the default amount with 1GB, which is at least better. STC I assume uses 2x the total amount. All of this could stand profiling, and I'm curious to try out [memray](https://bloomberg.github.io/memray/) on the Python stuff at least.
1 parent d9e92a9 commit e2c0fc0

File tree

5 files changed

+39
-11
lines changed

5 files changed

+39
-11
lines changed

fmriprep/workflows/bold/apply.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
def init_bold_volumetric_resample_wf(
1616
*,
1717
metadata: dict,
18+
mem_gb: dict[str, float],
1819
fieldmap_id: str | None = None,
1920
omp_nthreads: int = 1,
2021
name: str = 'bold_volumetric_resample_wf',
@@ -119,9 +120,14 @@ def init_bold_volumetric_resample_wf(
119120

120121
gen_ref = pe.Node(GenerateSamplingReference(), name='gen_ref', mem_gb=0.3)
121122

122-
boldref2target = pe.Node(niu.Merge(2), name='boldref2target')
123-
bold2target = pe.Node(niu.Merge(2), name='bold2target')
124-
resample = pe.Node(ResampleSeries(), name="resample", n_procs=omp_nthreads)
123+
boldref2target = pe.Node(niu.Merge(2), name='boldref2target', run_without_submitting=True)
124+
bold2target = pe.Node(niu.Merge(2), name='bold2target', run_without_submitting=True)
125+
resample = pe.Node(
126+
ResampleSeries(),
127+
name="resample",
128+
n_procs=omp_nthreads,
129+
mem_gb=mem_gb['resampled'],
130+
)
125131

126132
workflow.connect([
127133
(inputnode, gen_ref, [
@@ -156,10 +162,14 @@ def init_bold_volumetric_resample_wf(
156162
name="distortion_params",
157163
run_without_submitting=True,
158164
)
159-
fmap2target = pe.Node(niu.Merge(2), name='fmap2target')
160-
inverses = pe.Node(niu.Function(function=_gen_inverses), name='inverses')
165+
fmap2target = pe.Node(niu.Merge(2), name='fmap2target', run_without_submitting=True)
166+
inverses = pe.Node(
167+
niu.Function(function=_gen_inverses),
168+
name='inverses',
169+
run_without_submitting=True,
170+
)
161171

162-
fmap_recon = pe.Node(ReconstructFieldmap(), name="fmap_recon")
172+
fmap_recon = pe.Node(ReconstructFieldmap(), name="fmap_recon", mem_gb=1)
163173

164174
workflow.connect([
165175
(inputnode, fmap_select, [

fmriprep/workflows/bold/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def init_bold_wf(
313313
metadata=all_metadata[0],
314314
fieldmap_id=fieldmap_id if not multiecho else None,
315315
omp_nthreads=omp_nthreads,
316+
mem_gb=mem_gb,
316317
name='bold_anat_wf',
317318
)
318319
bold_anat_wf.inputs.inputnode.resolution = "native"
@@ -446,6 +447,7 @@ def init_bold_wf(
446447
metadata=all_metadata[0],
447448
fieldmap_id=fieldmap_id if not multiecho else None,
448449
omp_nthreads=omp_nthreads,
450+
mem_gb=mem_gb,
449451
name='bold_std_wf',
450452
)
451453
ds_bold_std_wf = init_ds_volumes_wf(
@@ -525,6 +527,7 @@ def init_bold_wf(
525527
metadata=all_metadata[0],
526528
fieldmap_id=fieldmap_id if not multiecho else None,
527529
omp_nthreads=omp_nthreads,
530+
mem_gb=mem_gb,
528531
name='bold_MNI6_wf',
529532
)
530533

@@ -537,7 +540,7 @@ def init_bold_wf(
537540

538541
bold_grayords_wf = init_bold_grayords_wf(
539542
grayord_density=config.workflow.cifti_output,
540-
mem_gb=mem_gb["resampled"],
543+
mem_gb=1,
541544
repetition_time=all_metadata[0]["RepetitionTime"],
542545
)
543546

fmriprep/workflows/bold/fit.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ def init_bold_native_wf(
789789

790790
# Slice-timing correction
791791
if run_stc:
792-
bold_stc_wf = init_bold_stc_wf(name="bold_stc_wf", metadata=metadata)
792+
bold_stc_wf = init_bold_stc_wf(metadata=metadata, mem_gb=mem_gb)
793793
workflow.connect([
794794
(inputnode, bold_stc_wf, [("dummy_scans", "inputnode.skip_vols")]),
795795
(validate_bold, bold_stc_wf, [("out_file", "inputnode.bold_file")]),
@@ -824,7 +824,12 @@ def init_bold_native_wf(
824824
]) # fmt:skip
825825

826826
# Resample to boldref
827-
boldref_bold = pe.Node(ResampleSeries(), name="boldref_bold", n_procs=omp_nthreads)
827+
boldref_bold = pe.Node(
828+
ResampleSeries(),
829+
name="boldref_bold",
830+
n_procs=omp_nthreads,
831+
mem_gb=mem_gb["resampled"],
832+
)
828833

829834
workflow.connect([
830835
(inputnode, boldref_bold, [
@@ -839,7 +844,7 @@ def init_bold_native_wf(
839844
]) # fmt:skip
840845

841846
if fieldmap_id:
842-
boldref_fmap = pe.Node(ReconstructFieldmap(inverse=[True]), name="boldref_fmap")
847+
boldref_fmap = pe.Node(ReconstructFieldmap(inverse=[True]), name="boldref_fmap", mem_gb=1)
843848
workflow.connect([
844849
(inputnode, boldref_fmap, [
845850
("boldref", "target_ref_file"),
@@ -858,6 +863,7 @@ def init_bold_native_wf(
858863
joinsource="echo_index",
859864
joinfield=["bold_files"],
860865
name="join_echos",
866+
run_without_submitting=True,
861867
)
862868

863869
# create optimal combination, adaptive T2* map

fmriprep/workflows/bold/resampling.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,12 +661,14 @@ def init_bold_fsLR_resampling_wf(
661661
metric_dilate = pe.Node(
662662
MetricDilate(distance=10, nearest=True),
663663
name="metric_dilate",
664+
mem_gb=1,
664665
n_procs=omp_nthreads,
665666
)
666667
mask_native = pe.Node(MetricMask(), name="mask_native")
667668
resample_to_fsLR = pe.Node(
668669
MetricResample(method='ADAP_BARY_AREA', area_surfs=True),
669670
name="resample_to_fsLR",
671+
mem_gb=1,
670672
n_procs=omp_nthreads,
671673
)
672674
# ... line 89
@@ -812,6 +814,7 @@ def init_bold_grayords_wf(
812814
grayordinates=grayord_density,
813815
),
814816
name="gen_cifti",
817+
mem_gb=mem_gb,
815818
)
816819

817820
workflow.connect([

fmriprep/workflows/bold/stc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ def _pre_run_hook(self, runtime):
5353
return runtime
5454

5555

56-
def init_bold_stc_wf(metadata: dict, name='bold_stc_wf'):
56+
def init_bold_stc_wf(
57+
*,
58+
mem_gb: dict,
59+
metadata: dict,
60+
name='bold_stc_wf',
61+
):
5762
"""
5863
Create a workflow for :abbr:`STC (slice-timing correction)`.
5964
@@ -119,6 +124,7 @@ def init_bold_stc_wf(metadata: dict, name='bold_stc_wf'):
119124
slice_encoding_direction=metadata.get('SliceEncodingDirection', 'k'),
120125
tzero=tzero,
121126
),
127+
mem_gb=mem_gb['filesize'] * 2,
122128
name='slice_timing_correction',
123129
)
124130

0 commit comments

Comments
 (0)