Skip to content

Commit 8d03291

Browse files
authored
Merge pull request #487 from tsalo/fix-in-place-queries
FIX: Create copy of query before modifying
2 parents 29f0def + 126d5c6 commit 8d03291

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

src/smriprep/utils/bids.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def collect_derivatives(
5959
qry_base['session'] = session_id
6060

6161
for key, qry in spec['baseline'].items():
62-
qry |= qry_base
62+
qry = {**qry, **qry_base}
6363
item = layout.get(**qry)
6464
if not item:
6565
continue
@@ -76,7 +76,7 @@ def collect_derivatives(
7676
for _space in std_spaces:
7777
space = _space.replace(':cohort-', '+')
7878
for key, qry in spec['transforms'].items():
79-
qry |= qry_base
79+
qry = {**qry, **qry_base}
8080
qry['from'] = qry['from'] or space
8181
qry['to'] = qry['to'] or space
8282
item = layout.get(return_type='filename', **qry)
@@ -85,7 +85,7 @@ def collect_derivatives(
8585
transforms.setdefault(_space, {})[key] = item[0] if len(item) == 1 else item
8686

8787
for key, qry in spec['surfaces'].items():
88-
qry |= qry_base
88+
qry = {**qry, **qry_base}
8989
item = layout.get(return_type='filename', **qry)
9090
if not item or len(item) != 2:
9191
continue

src/smriprep/utils/tests/test_bids.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1+
import pytest
12
from niworkflows.utils.testing import generate_bids_skeleton
23

34
from ..bids import collect_derivatives
45
from . import DERIV_SKELETON
56

67

7-
def test_collect_derivatives(tmp_path):
8+
@pytest.fixture
9+
def deriv_dset(tmp_path):
810
deriv_dir = tmp_path / 'derivatives'
911
generate_bids_skeleton(deriv_dir, str(DERIV_SKELETON))
12+
return deriv_dir
13+
14+
15+
def test_collect_derivatives(deriv_dset):
1016
output_spaces = ['MNI152NLin2009cAsym', 'MNIPediatricAsym:cohort-3']
11-
collected = collect_derivatives(deriv_dir, '01', output_spaces)
17+
collected = collect_derivatives(deriv_dset, '01', output_spaces)
1218
for suffix in ('preproc', 'mask', 'dseg'):
1319
assert collected[f't1w_{suffix}']
1420
assert len(collected['t1w_tpms']) == 3
@@ -28,3 +34,14 @@ def test_collect_derivatives(tmp_path):
2834
'sphere_reg_msm',
2935
):
3036
assert len(collected[surface]) == 2
37+
38+
39+
def test_collect_derivatives_transforms(deriv_dset):
40+
"""Ensure transforms are collected for the right spaces."""
41+
output_spaces = ['MNI152NLin2009cAsym', 'MNIPediatricAsym:cohort-3']
42+
collected = collect_derivatives(deriv_dset, '01', output_spaces)
43+
xfms = collected['transforms']
44+
for space in output_spaces:
45+
template = space.split(':')[0]
46+
assert template in xfms[space]['reverse']
47+
assert template in xfms[space]['forward']

0 commit comments

Comments
 (0)