Skip to content

Commit d945730

Browse files
authored
Fix babs init --list_sub_file (#345)
1 parent 1a10cdd commit d945730

File tree

2 files changed

+90
-2
lines changed

2 files changed

+90
-2
lines changed

babs/bootstrap.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,13 @@ def babs_bootstrap(
256256

257257
# Copy in any other files needed:
258258
self._init_import_files(container.config.get('imported_files', []))
259-
# Create the inclusion file
260-
self._update_inclusion_dataframe(initial_inclusion_df)
259+
# _update_inclusion_dataframe() expects a DataFrame (or None).
260+
# If --list_sub_file was provided, use the parsed DataFrame
261+
# stored in initial_inclu_df by set_inclusion_dataframe() above.
262+
inclusion_df_for_update = (
263+
self.input_datasets.initial_inclu_df if initial_inclusion_df is not None else None
264+
)
265+
self._update_inclusion_dataframe(inclusion_df_for_update)
261266

262267
# Generate the template of job submission: --------------------------------
263268
print('\nGenerating templates for job submission calls...')

tests/test_babs_workflow.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import os
55
import os.path as op
66
import time
7+
from glob import glob
78
from pathlib import Path
89
from unittest import mock
910

11+
import pandas as pd
1012
import pytest
1113
from conftest import (
1214
ensure_container_image,
@@ -194,6 +196,87 @@ def _get_results_branches_use_merge_ds_when_exists(self):
194196
raise
195197

196198

199+
@pytest.mark.parametrize('processing_level', ['subject', 'session'])
def test_babs_init_list_sub_file(
    tmp_path_factory,
    templateflow_home,
    bids_data_singlesession,
    bids_data_multisession,
    processing_level,
    simbids_container_ds,
):
    """Check that `babs init --list_sub_file` honors the user-supplied list.

    A two-row CSV of IDs found under the BIDS input directory is written and
    handed to `babs init` through ``list_sub_file``.  After initialization,
    ``analysis/code/processing_inclusion.csv`` must contain exactly those rows
    (compared order-insensitively).

    Runs once per ``processing_level``: ``'subject'`` uses the single-session
    BIDS fixture and lists ``sub_id`` only; ``'session'`` uses the
    multi-session fixture and lists ``sub_id``/``ses_id`` pairs.
    """
    os.environ['TEMPLATEFLOW_HOME'] = str(templateflow_home)

    by_session = processing_level == 'session'

    project_base = tmp_path_factory.mktemp('project')
    project_root = project_base / 'my_babs_project'
    container_name = 'simbids-0-0-3'

    if by_session:
        bids_path = bids_data_multisession
    else:
        bids_path = bids_data_singlesession

    subject_dirs = sorted(glob(op.join(bids_path, 'sub-*')))
    assert len(subject_dirs) >= 2, 'Need at least 2 subjects in BIDS path'

    # Create a small list CSV with 2 rows (IDs must exist in BIDS)
    if by_session:
        # Lazily walk subjects, then their sessions, in sorted order ...
        session_pairs = (
            {'sub_id': op.basename(subject_dir), 'ses_id': op.basename(session_dir)}
            for subject_dir in subject_dirs
            for session_dir in sorted(glob(op.join(subject_dir, 'ses-*')))
        )
        # ... and stop as soon as two pairs have been collected.
        selected = [pair for _, pair in zip(range(2), session_pairs)]
        assert len(selected) >= 2, 'Need at least 2 session rows in BIDS path'
    else:
        selected = [{'sub_id': op.basename(subject_dir)} for subject_dir in subject_dirs[:2]]

    list_df = pd.DataFrame(selected)
    list_sub_file = project_base / 'list_sub.csv'
    list_df.to_csv(list_sub_file, index=False)

    # Point the container config at the BIDS dataset chosen above.
    container_config = update_yaml_for_run(
        project_base,
        get_config_simbids_path().name,
        {'BIDS': str(bids_path)},
    )

    # Run `babs init` with a pre-built namespace standing in for CLI parsing.
    init_args = argparse.Namespace(
        project_root=project_root,
        list_sub_file=str(list_sub_file),
        container_ds=simbids_container_ds,
        container_name=container_name,
        container_config=str(container_config),
        processing_level=processing_level,
        queue='slurm',
        keep_if_failed=True,
    )
    with mock.patch.object(argparse.ArgumentParser, 'parse_args', return_value=init_args):
        _enter_init()

    assert project_root.exists()
    inclusion_csv = project_root / 'analysis' / 'code' / 'processing_inclusion.csv'
    assert inclusion_csv.exists()

    written_df = pd.read_csv(inclusion_csv)
    assert 'sub_id' in written_df.columns
    assert len(written_df) == len(list_df)

    if not by_session:
        # Subject level: only the subject IDs need to match the list.
        written_ids = written_df['sub_id'].sort_values().reset_index(drop=True)
        expected_ids = list_df['sub_id'].sort_values().reset_index(drop=True)
        pd.testing.assert_series_equal(written_ids, expected_ids)
    else:
        # Session level: the whole frame must match once both are sorted.
        assert 'ses_id' in written_df.columns
        sort_keys = ['sub_id', 'ses_id']
        written_sorted = written_df.sort_values(sort_keys).reset_index(drop=True)
        expected_sorted = list_df.sort_values(sort_keys).reset_index(drop=True)
        pd.testing.assert_frame_equal(written_sorted, expected_sorted)
278+
279+
197280
def test_bootstrap_cleanup(babs_project_sessionlevel_babsobject):
198281
"""Test that the cleanup method properly removes a partially created project."""
199282

0 commit comments

Comments
 (0)