Skip to content

Commit 9b3b246

Browse files
committed
Add handling of the folder suffix and pipeline choice
1 parent 4bdbb3e commit 9b3b246

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

derotation/derotate_batch.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from pathlib import Path
1212

1313
from derotation.analysis.full_derotation_pipeline import FullPipeline
14+
from derotation.analysis.incremental_derotation_pipeline import (
15+
IncrementalPipeline,
16+
)
1417
from derotation.config.load_config import load_config, update_config_paths
1518

1619

@@ -20,9 +23,12 @@ def derotate(
2023
path_to_stimulus_randperm: str,
2124
glob_naming_pattern_tif: str,
2225
glob_naming_pattern_bin: str,
23-
) -> float:
26+
folder_suffix: str = "full",
27+
):
2428
"""
25-
Run the full derotation pipeline on a single dataset.
29+
Run the derotation pipeline on a single dataset. This function is
30+
responsible for loading the configuration, updating the paths, the
31+
pipeline choice, and running the pipeline.
2632
2733
Parameters
2834
----------
@@ -36,11 +42,9 @@ def derotate(
3642
The glob naming pattern for the tif file.
3743
glob_naming_pattern_bin : str
3844
The glob naming pattern for the bin file.
39-
40-
Returns
41-
-------
42-
float
43-
The metric calculated by the pipeline.
45+
folder_suffix : str, optional
46+
The suffix to append to the output folder name, by default "full".
47+
This is used to differentiate between full and incremental pipelines.
4448
4549
Raises
4650
------
@@ -60,16 +64,18 @@ def derotate(
6064
aux_path=str(bin_path),
6165
stim_randperm_path=str(path_to_stimulus_randperm),
6266
output_folder=output_folder,
63-
folder_suffix="full",
67+
folder_suffix=folder_suffix,
6468
)
6569

6670
logging.info("Running full derotation pipeline")
6771

6872
# Run the pipeline
6973
try:
70-
derotator = FullPipeline(config)
74+
if folder_suffix == "full":
75+
derotator = FullPipeline(config)
76+
else:
77+
derotator = IncrementalPipeline(config)
7178
derotator()
72-
return derotator.metric
7379
except Exception as e:
7480
logging.error("Full derotation pipeline failed")
7581
logging.error(e.args)

0 commit comments

Comments
 (0)