
Commit 4f47ee2

Author: Paolo Tranquilli
MaD: make bulk generator DCA strategy download DBs in parallel
1 parent fbd5058 commit 4f47ee2

File tree

1 file changed: +58 -38 lines changed

misc/scripts/models-as-data/bulk_generate_mad.py

@@ -8,7 +8,7 @@
 import os.path
 import subprocess
 import sys
-from typing import NotRequired, TypedDict, List
+from typing import NotRequired, TypedDict, List, Callable, Optional
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import time
 import argparse
@@ -111,6 +111,37 @@ def clone_project(project: Project) -> str:
     return target_dir


+def run_in_parallel[T, U](
+    func: Callable[[T], U],
+    items: List[T],
+    *,
+    on_error=lambda item, exc: None,
+    error_summary=lambda failures: None,
+    max_workers=8,
+) -> List[Optional[U]]:
+    if not items:
+        return []
+    max_workers = min(max_workers, len(items))
+    results = [None for _ in range(len(items))]
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Start tasks and keep track of them
+        futures = {
+            executor.submit(func, item): index for index, item in enumerate(items)
+        }
+        # Process results as they complete
+        for future in as_completed(futures):
+            index = futures[future]
+            try:
+                results[index] = future.result()
+            except Exception as e:
+                on_error(items[index], e)
+    failed = [item for item, result in zip(items, results) if result is None]
+    if failed:
+        error_summary(failed)
+        sys.exit(1)
+    return results
+
+
 def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
     """
     Clone all projects in parallel.
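For reference, a minimal usage sketch of the new helper follows; it is not part of the commit, `square` and the inputs are illustrative, and importing from bulk_generate_mad assumes the script is reachable on the import path. Note that the `def run_in_parallel[T, U]` generic syntax requires Python 3.12 or later.

# Hypothetical usage of run_in_parallel; names and inputs are illustrative.
from bulk_generate_mad import run_in_parallel

def square(n: int) -> int:
    if n == 3:
        raise ValueError("boom")  # simulated failure for the error path below
    return n * n

# Happy path: results come back positionally, in the same order as the inputs.
print(run_in_parallel(square, [1, 2, 4]))  # [1, 4, 16]

# Failure path: on_error fires per failed item as its future completes, then
# error_summary fires once with every failed item and the process exits(1).
run_in_parallel(
    square,
    [1, 2, 3],
    on_error=lambda item, exc: print(f"failed on {item}: {exc}"),
    error_summary=lambda failures: print(f"{len(failures)} item(s) failed"),
)

One caveat visible in the implementation: failures are detected by None entries in the results list, so a task that legitimately returns None would be treated as failed.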
@@ -122,40 +153,19 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
         List of (project, project_dir) pairs in the same order as the input projects
     """
     start_time = time.time()
-    max_workers = min(8, len(projects))  # Use at most 8 threads
-    project_dirs_map = {}  # Map to store results by project name
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        # Start cloning tasks and keep track of them
-        future_to_project = {
-            executor.submit(clone_project, project): project for project in projects
-        }
-
-        # Process results as they complete
-        for future in as_completed(future_to_project):
-            project = future_to_project[future]
-            try:
-                project_dir = future.result()
-                project_dirs_map[project["name"]] = (project, project_dir)
-            except Exception as e:
-                print(f"ERROR: Failed to clone {project['name']}: {e}")
-
-    if len(project_dirs_map) != len(projects):
-        failed_projects = [
-            project["name"]
-            for project in projects
-            if project["name"] not in project_dirs_map
-        ]
-        print(
-            f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
-        )
-        sys.exit(1)
-
-    project_dirs = [project_dirs_map[project["name"]] for project in projects]
-
+    dirs = run_in_parallel(
+        clone_project,
+        projects,
+        on_error=lambda project, exc: print(
+            f"ERROR: Failed to clone project {project['name']}: {exc}"
+        ),
+        error_summary=lambda failures: print(
+            f"ERROR: Failed to clone {len(failures)} projects: {', '.join(p['name'] for p in failures)}"
+        ),
+    )
     clone_time = time.time() - start_time
     print(f"Cloning completed in {clone_time:.2f} seconds")
-    return project_dirs
+    return list(zip(projects, dirs))


 def build_database(
@@ -352,7 +362,8 @@ def download_dca_databases(

         artifact_map[pretty_name] = analyzed_database

-    for pretty_name, analyzed_database in artifact_map.items():
+    def download(item: tuple[str, dict]) -> str:
+        pretty_name, analyzed_database = item
         artifact_name = analyzed_database["artifact_name"]
         repository = analyzed_database["repository"]
         run_id = analyzed_database["run_id"]
@@ -383,13 +394,22 @@ def download_dca_databases(
             with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
                 # And we just untar it to the same directory as the zip file
                 tar_ref.extractall(artifact_unzipped_location)
-                database_results[pretty_name] = os.path.join(
-                    artifact_unzipped_location, remove_extension(entry)
-                )
+                return os.path.join(artifact_unzipped_location, remove_extension(entry))
+
+    results = run_in_parallel(
+        download,
+        list(artifact_map.items()),
+        on_error=lambda item, exc: print(
+            f"ERROR: Failed to download database for {item[0]}: {exc}"
+        ),
+        error_summary=lambda failures: print(
+            f"ERROR: Failed to download {len(failures)} databases: {', '.join(item[0] for item in failures)}"
+        ),
+    )

     print(f"\n=== Extracted {len(database_results)} databases ===")

-    return [(project, database_results[project["name"]]) for project in projects]
+    return [(project_map[n], r) for n, r in zip(artifact_map, results)]


 def get_mad_destination_for_project(config, name: str) -> str:
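The rewritten return statement leans on two ordering guarantees, sketched below with made-up names: iterating a dict yields its keys in insertion order, and run_in_parallel returns results positionally, so zipping artifact_map's keys with results re-associates each extracted database path with its project name (which project_map then resolves back to a Project).

# Illustrative only: the map contents and result paths are invented.
artifact_map = {"proj-a": {"run_id": 1}, "proj-b": {"run_id": 2}}
results = ["dbs/proj-a/db", "dbs/proj-b/db"]  # same order as list(artifact_map.items())
print(list(zip(artifact_map, results)))
# [('proj-a', 'dbs/proj-a/db'), ('proj-b', 'dbs/proj-b/db')]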
