Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions babs/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,13 @@ def babs_bootstrap(

# Copy in any other files needed:
self._init_import_files(container.config.get('imported_files', []))
# Create the inclusion file
self._update_inclusion_dataframe(initial_inclusion_df)
# _update_inclusion_dataframe() expects a DataFrame (or None).
# If --list_sub_file was provided, use the parsed DataFrame
# stored in initial_inclu_df by set_inclusion_dataframe() above.
inclusion_df_for_update = (
self.input_datasets.initial_inclu_df if initial_inclusion_df is not None else None
)
self._update_inclusion_dataframe(inclusion_df_for_update)

# Generate the template of job submission: --------------------------------
print('\nGenerating templates for job submission calls...')
Expand Down
83 changes: 83 additions & 0 deletions tests/test_babs_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import os
import os.path as op
import time
from glob import glob
from pathlib import Path
from unittest import mock

import pandas as pd
import pytest
from conftest import (
ensure_container_image,
Expand Down Expand Up @@ -194,6 +196,87 @@ def _get_results_branches_use_merge_ds_when_exists(self):
raise


@pytest.mark.parametrize('processing_level', ['subject', 'session'])
def test_babs_init_list_sub_file(
tmp_path_factory,
templateflow_home,
bids_data_singlesession,
bids_data_multisession,
processing_level,
simbids_container_ds,
):
"""Test `babs init` with --list_sub_file: inclusion list matches the provided CSV."""
os.environ['TEMPLATEFLOW_HOME'] = str(templateflow_home)

project_base = tmp_path_factory.mktemp('project')
project_root = project_base / 'my_babs_project'
container_name = 'simbids-0-0-3'

bids_path = (
bids_data_multisession if processing_level == 'session' else bids_data_singlesession
)
sub_dirs = sorted(glob(op.join(bids_path, 'sub-*')))
assert len(sub_dirs) >= 2, 'Need at least 2 subjects in BIDS path'

# Create a small list CSV with 2 rows (IDs must exist in BIDS)
if processing_level == 'session':
rows = []
for sub_dir in sub_dirs:
for ses_dir in sorted(glob(op.join(sub_dir, 'ses-*'))):
rows.append({'sub_id': op.basename(sub_dir), 'ses_id': op.basename(ses_dir)})
if len(rows) >= 2:
break
if len(rows) >= 2:
break
assert len(rows) >= 2, 'Need at least 2 session rows in BIDS path'
else:
rows = [
{'sub_id': op.basename(sub_dirs[0])},
{'sub_id': op.basename(sub_dirs[1])},
]

list_df = pd.DataFrame(rows[:2])
list_sub_file = project_base / 'list_sub.csv'
list_df.to_csv(list_sub_file, index=False)

config_simbids_path = get_config_simbids_path()
container_config = update_yaml_for_run(
project_base,
config_simbids_path.name,
{'BIDS': str(bids_path)},
)

babs_init_opts = argparse.Namespace(
project_root=project_root,
list_sub_file=str(list_sub_file),
container_ds=simbids_container_ds,
container_name=container_name,
container_config=str(container_config),
processing_level=processing_level,
queue='slurm',
keep_if_failed=True,
)
with mock.patch.object(argparse.ArgumentParser, 'parse_args', return_value=babs_init_opts):
_enter_init()

assert project_root.exists()
inclusion_csv = project_root / 'analysis' / 'code' / 'processing_inclusion.csv'
assert inclusion_csv.exists()
df = pd.read_csv(inclusion_csv)
assert 'sub_id' in df.columns
assert len(df) == len(list_df)
if processing_level == 'session':
assert 'ses_id' in df.columns
df_sorted = df.sort_values(['sub_id', 'ses_id']).reset_index(drop=True)
list_sorted = list_df.sort_values(['sub_id', 'ses_id']).reset_index(drop=True)
pd.testing.assert_frame_equal(df_sorted, list_sorted)
else:
pd.testing.assert_series_equal(
df['sub_id'].sort_values().reset_index(drop=True),
list_df['sub_id'].sort_values().reset_index(drop=True),
)


def test_bootstrap_cleanup(babs_project_sessionlevel_babsobject):
"""Test that the cleanup method properly removes a partially created project."""

Expand Down