|
1 | 1 | """Test the check_setup functionality.""" |
2 | 2 |
|
| 3 | +import os |
3 | 4 | import os.path as op |
4 | 5 | import random |
| 6 | +import re |
5 | 7 | from pathlib import Path |
6 | 8 | from unittest.mock import MagicMock |
7 | 9 |
|
@@ -202,3 +204,54 @@ def test_key_info_full(babs_project_sessionlevel): |
202 | 204 | babs_proj.wtf_key_info(flag_output_ria_only=False) |
203 | 205 | assert babs_proj.output_ria_data_dir is not None |
204 | 206 | assert babs_proj.analysis_dataset_id is not None |
| 207 | + |
| 208 | + |
| 209 | +@pytest.mark.parametrize( |
| 210 | + ('throttle_value', 'expected_in_template'), |
| 211 | + [(10, True), (None, False)], |
| 212 | +) |
| 213 | +def test_throttle_in_job_template( |
| 214 | + tmp_path_factory, |
| 215 | + templateflow_home, |
| 216 | + simbids_container_ds, |
| 217 | + bids_data_singlesession, |
| 218 | + throttle_value, |
| 219 | + expected_in_template, |
| 220 | +): |
| 221 | + """Test that throttle value is correctly included in job submission template.""" |
| 222 | + from conftest import get_config_simbids_path, update_yaml_for_run |
| 223 | + |
| 224 | + from babs.bootstrap import BABSBootstrap |
| 225 | + |
| 226 | + os.environ['TEMPLATEFLOW_HOME'] = str(templateflow_home) |
| 227 | + |
| 228 | + project_base = tmp_path_factory.mktemp('project') |
| 229 | + project_root = project_base / f'my_babs_project_{throttle_value or "none"}' |
| 230 | + container_config = update_yaml_for_run( |
| 231 | + project_base, |
| 232 | + get_config_simbids_path().name, |
| 233 | + {'BIDS': bids_data_singlesession}, |
| 234 | + ) |
| 235 | + |
| 236 | + babs_bootstrap = BABSBootstrap(project_root=project_root) |
| 237 | + babs_bootstrap.babs_bootstrap( |
| 238 | + processing_level='subject', |
| 239 | + queue='slurm', |
| 240 | + container_ds=simbids_container_ds, |
| 241 | + container_name='simbids-0-0-3', |
| 242 | + container_config=container_config, |
| 243 | + initial_inclusion_df=None, |
| 244 | + throttle=throttle_value, |
| 245 | + ) |
| 246 | + |
| 247 | + assert babs_bootstrap.throttle == throttle_value |
| 248 | + |
| 249 | + template_path = op.join(babs_bootstrap.analysis_path, 'code', 'submit_job_template.yaml') |
| 250 | + with open(template_path) as f: |
| 251 | + cmd_template = yaml.safe_load(f)['cmd_template'] |
| 252 | + |
| 253 | + assert '--array=1-${max_array}' in cmd_template |
| 254 | + if expected_in_template: |
| 255 | + assert f'%{throttle_value}' in cmd_template |
| 256 | + else: |
| 257 | + assert not re.search(r'%\d+', cmd_template) |
0 commit comments