Skip to content

Commit 46a9a24

Browse files
committed
chore: Plumb up the CLI args
1 parent 9ea1802 commit 46a9a24

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

scripts/fetch_test_data.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def process_sample_data_request(
109109
request: DataRequest,
110110
decimate: bool,
111111
output_directory: Path,
112+
n_jobs: int | None = -1,
112113
) -> pd.DataFrame:
113114
"""
114115
Fetch and create sample datasets
@@ -125,6 +126,9 @@ def process_sample_data_request(
125126
Whether to decimate the datasets
126127
output_directory
127128
The directory to write the output to
129+
n_jobs
130+
Number of jobs to run in parallel
131+
If None, run sequentially.
128132
129133
Returns
130134
-------
@@ -134,10 +138,17 @@ def process_sample_data_request(
134138
datasets = request.fetch_datasets()
135139

136140
# Process all the datasets in parallel
137-
items = joblib.Parallel(n_jobs=-1)(
138-
joblib.delayed(_process_dataset)(processed_datasets, dataset, request, decimate, output_directory)
139-
for _, dataset in datasets.iterrows()
140-
)
141+
if n_jobs is None:
142+
logger.info("Processing datasets sequentially as n_jobs is None")
143+
items = [
144+
_process_dataset(processed_datasets, dataset, request, decimate, output_directory)
145+
for _, dataset in datasets.iterrows()
146+
]
147+
else:
148+
items = joblib.Parallel(n_jobs=n_jobs)(
149+
joblib.delayed(_process_dataset)(processed_datasets, dataset, request, decimate, output_directory)
150+
for _, dataset in datasets.iterrows()
151+
)
141152
# Flatten the list of lists
142153
items = [item for sublist in items for item in sublist]
143154

@@ -365,10 +376,16 @@ def create_sample_data(
365376
decimate: bool = True,
366377
output: Path = OUTPUT_PATH,
367378
only: list[str] | None = None,
379+
n_jobs: int = -1,
380+
run_sequentially: bool = False,
368381
) -> None:
369382
"""Fetch and create sample datasets"""
370383
processed_datasets = pd.DataFrame(columns=["source_type", "key", "files", "time_start", "time_end"])
371384

385+
if run_sequentially:
386+
n_jobs = None
387+
logger.info("Running in sequential mode, setting n_jobs to None")
388+
372389
for dataset_requested in DATASETS_TO_FETCH:
373390
if only:
374391
if dataset_requested.id not in only:
@@ -380,6 +397,7 @@ def create_sample_data(
380397
dataset_requested,
381398
decimate=decimate,
382399
output_directory=pathlib.Path(output),
400+
n_jobs=n_jobs,
383401
)
384402
# Remove duplicate source_type and key values, but keep the latest one
385403
processed_datasets = (

0 commit comments

Comments (0)