Skip to content

Commit 56ab528

Browse files
authored
Merge pull request #141 from nipreps/add_fs_seed
Add FreeSurfer seed to segmentation workflows
2 parents 6ced197 + 8316eea commit 56ab528

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

petprep/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,8 @@ class seeds(_Config):
671671
"""Seed used for antsRegistration, antsAI, antsMotionCorr"""
672672
numpy = None
673673
"""Seed used by NumPy"""
674+
freesurfer = None
675+
"""Seed used by FreeSurfer utilities"""
674676

675677
@classmethod
676678
def init(cls):
@@ -682,6 +684,7 @@ def init(cls):
682684
# functions to set program specific seeds
683685
cls.ants = _set_ants_seed()
684686
cls.numpy = _set_numpy_seed()
687+
cls.freesurfer = _set_freesurfer_seed()
685688

686689

687690
def _set_ants_seed():
@@ -700,6 +703,13 @@ def _set_numpy_seed():
700703
return val
701704

702705

706+
def _set_freesurfer_seed():
707+
"""Fix random seed for FreeSurfer utilities"""
708+
val = random.randint(1, 65536)
709+
os.environ['FREESURFER_RANDOM_SEED'] = str(val)
710+
return val
711+
712+
703713
def from_dict(settings, init=True, ignore=None):
704714
"""Read settings from a flat dictionary.
705715

petprep/interfaces/segmentation.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727
from nipype.interfaces.freesurfer.petsurfer import GTMSeg
2828
from nipype.utils.filemanip import fname_presuffix
2929

30+
from .. import config
31+
32+
33+
def _set_freesurfer_seed(runtime):
34+
runtime.environ['FREESURFER_RANDOM_SEED'] = str(config.seeds.freesurfer)
35+
return runtime
36+
3037

3138
class SegmentBSInputSpec(BaseInterfaceInputSpec):
3239
subjects_dir = Directory(exists=True, mandatory=True, desc='FreeSurfer subjects directory')
@@ -143,6 +150,7 @@ class MRISclimbicSeg(CommandLine):
143150
output_spec = MRISclimbicSegOutputSpec
144151

145152
def _run_interface(self, runtime):
153+
_set_freesurfer_seed(runtime)
146154
outputs = self._list_outputs()
147155
expected = [outputs['out_file'], outputs['out_stats']]
148156

@@ -186,6 +194,7 @@ class SegmentHA_T1(FSCommand):
186194
output_spec = SegmentHA_T1OutputSpec
187195

188196
def _run_interface(self, runtime):
197+
_set_freesurfer_seed(runtime)
189198
fs_path = os.path.join(self.inputs.subjects_dir, self.inputs.subject_id, 'mri')
190199
expected = [
191200
'lh.hippoAmygLabels-T1.v22.FSvoxelSpace.mgz',
@@ -324,6 +333,7 @@ class SegmentGTM(GTMSeg):
324333
"""Run ``gtmseg`` unless outputs already exist."""
325334

326335
def _run_interface(self, runtime):
336+
_set_freesurfer_seed(runtime)
327337
subj_dir = Path(self.inputs.subjects_dir) / self.inputs.subject_id
328338
seg_file = subj_dir / 'mri' / self.inputs.out_file
329339
stats_file = subj_dir / 'stats' / Path(self.inputs.out_file).with_suffix('.stats').name
@@ -563,6 +573,10 @@ def _list_outputs(self):
563573
outputs[name] = os.path.abspath(value)
564574
return outputs
565575

576+
def _run_interface(self, runtime):
577+
_set_freesurfer_seed(runtime)
578+
return super()._run_interface(runtime)
579+
566580
def _format_arg(self, name, spec, value):
567581
if name in ('summary_file', 'ctab_out_file', 'avgwf_txt_file'):
568582
if not isinstance(value, bool):

petprep/interfaces/tests/test_segmentation_interface.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from pathlib import Path
2+
from types import SimpleNamespace
23

3-
from ..segmentation import SegmentBS, SegmentGTM, SegmentWM
4+
from ... import config
5+
from ..segmentation import MRISclimbicSeg, SegmentBS, SegmentGTM, SegmentWM, _set_freesurfer_seed
46

57

68
def test_segmentgtm_skip(tmp_path):
@@ -15,6 +17,24 @@ def test_segmentgtm_skip(tmp_path):
1517

1618
assert res.runtime.returncode == 0
1719
assert Path(res.outputs.out_file) == subj_dir / 'mri' / 'gtmseg.mgz'
20+
assert res.runtime.environ['FREESURFER_RANDOM_SEED'] == str(config.seeds.freesurfer)
21+
22+
23+
def test_mrisclimbicseg_seed(tmp_path):
24+
subjects_dir = tmp_path / 'subjects'
25+
subject_dir = subjects_dir / 'sub-01'
26+
subject_dir.mkdir(parents=True)
27+
28+
out_file = subject_dir / 'sub-01_sclimbic.nii.gz'
29+
out_stats = subject_dir / 'sub-01_sclimbic.stats'
30+
out_file.write_text('')
31+
out_stats.write_text('')
32+
33+
seg = MRISclimbicSeg(out_file=str(out_file), sd=str(subjects_dir), subjects=['sub-01'])
34+
res = seg.run()
35+
36+
assert res.runtime.returncode == 0
37+
assert res.runtime.environ['FREESURFER_RANDOM_SEED'] == str(config.seeds.freesurfer)
1838

1939

2040
def _fake_bs_run(self, cmd):
@@ -47,3 +67,11 @@ def test_segmentwm_stdout_stderr(monkeypatch, tmp_path):
4767
res = seg.run()
4868
assert res.outputs.stdout == 'wm out'
4969
assert res.outputs.stderr == 'wm err'
70+
71+
72+
def test_set_freesurfer_seed_runtime():
73+
runtime = SimpleNamespace(environ={})
74+
75+
runtime = _set_freesurfer_seed(runtime)
76+
77+
assert runtime.environ['FREESURFER_RANDOM_SEED'] == str(config.seeds.freesurfer)

petprep/tests/test_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,21 @@ def test_config_spaces():
104104

105105

106106
@pytest.mark.parametrize(
107-
('master_seed', 'ants_seed', 'numpy_seed'), [(1, 17612, 8272), (100, 19094, 60232)]
107+
('master_seed', 'ants_seed', 'numpy_seed', 'freesurfer_seed'),
108+
[(1, 17612, 8272, 33433), (100, 19094, 60232, 59629)],
108109
)
109-
def test_prng_seed(master_seed, ants_seed, numpy_seed):
110+
def test_prng_seed(master_seed, ants_seed, numpy_seed, freesurfer_seed):
110111
"""Ensure seeds are properly tracked"""
111112
seeds = config.seeds
112113
with patch.dict(os.environ, {}):
113114
seeds.load({'_random_seed': master_seed}, init=True)
114115
assert seeds.master == master_seed
115116
assert seeds.ants == ants_seed
116117
assert seeds.numpy == numpy_seed
118+
assert seeds.freesurfer == freesurfer_seed
117119
assert os.getenv('ANTS_RANDOM_SEED') == str(ants_seed)
120+
assert os.getenv('FREESURFER_RANDOM_SEED') == str(freesurfer_seed)
118121

119122
_reset_config()
120-
for seed in ('_random_seed', 'master', 'ants', 'numpy'):
123+
for seed in ('_random_seed', 'master', 'ants', 'numpy', 'freesurfer'):
121124
assert getattr(config.seeds, seed) is None

0 commit comments

Comments
 (0)