Skip to content

Commit 7ecf8c8

Browse files
committed
Bulk generator: Format file and add a note at the top of the file specifying the formatting requirements.
1 parent cb93870 commit 7ecf8c8

File tree

1 file changed

+117
-54
lines changed

1 file changed

+117
-54
lines changed

misc/scripts/models-as-data/bulk_generate_mad.py

Lines changed: 117 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""
22
Experimental script for bulk generation of MaD models based on a list of projects.
3+
4+
Note: This file must be formatted using the Black Python formatter.
35
"""
46

57
import os.path
@@ -24,6 +26,7 @@
2426
)
2527
build_dir = os.path.join(gitroot, "mad-generation-build")
2628

29+
2730
# A project to generate models for
2831
class Project(TypedDict):
2932
"""
@@ -132,7 +135,9 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
132135
return project_dirs
133136

134137

135-
def build_database(language: str, extractor_options, project: Project, project_dir: str) -> str | None:
138+
def build_database(
139+
language: str, extractor_options, project: Project, project_dir: str
140+
) -> str | None:
136141
"""
137142
Build a CodeQL database for a project.
138143
@@ -179,6 +184,7 @@ def build_database(language: str, extractor_options, project: Project, project_d
179184

180185
return database_dir
181186

187+
182188
def generate_models(args, name: str, database_dir: str) -> None:
183189
"""
184190
Generate models for a project.
@@ -196,7 +202,10 @@ def generate_models(args, name: str, database_dir: str) -> None:
196202
generator.setenvironment(database=database_dir, folder=name)
197203
generator.run()
198204

199-
def build_databases_from_projects(language: str, extractor_options, projects: List[Project]) -> List[tuple[str, str | None]]:
205+
206+
def build_databases_from_projects(
207+
language: str, extractor_options, projects: List[Project]
208+
) -> List[tuple[str, str | None]]:
200209
"""
201210
Build databases for all projects in parallel.
202211
@@ -215,11 +224,15 @@ def build_databases_from_projects(language: str, extractor_options, projects: Li
215224
# Phase 2: Build databases for all projects
216225
print("\n=== Phase 2: Building databases ===")
217226
database_results = [
218-
(project["name"], build_database(language, extractor_options, project, project_dir))
227+
(
228+
project["name"],
229+
build_database(language, extractor_options, project, project_dir),
230+
)
219231
for project, project_dir in project_dirs
220232
]
221233
return database_results
222234

235+
223236
def github(url: str, pat: str, extra_headers: dict[str, str] = {}) -> dict:
224237
"""
225238
Download a JSON file from GitHub using a personal access token (PAT).
@@ -230,14 +243,15 @@ def github(url: str, pat: str, extra_headers: dict[str, str] = {}) -> dict:
230243
Returns:
231244
The JSON response as a dictionary.
232245
"""
233-
headers = { "Authorization": f"token {pat}" } | extra_headers
246+
headers = {"Authorization": f"token {pat}"} | extra_headers
234247
response = requests.get(url, headers=headers)
235248
if response.status_code != 200:
236249
print(f"Failed to download JSON: {response.status_code} {response.text}")
237250
sys.exit(1)
238251
else:
239252
return response.json()
240253

254+
241255
def download_artifact(url: str, artifact_name: str, pat: str) -> str:
242256
"""
243257
Download a GitHub Actions artifact from a given URL.
@@ -248,7 +262,7 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
248262
Returns:
249263
The path to the downloaded artifact file.
250264
"""
251-
headers = { "Authorization": f"token {pat}", "Accept": "application/vnd.github+json" }
265+
headers = {"Authorization": f"token {pat}", "Accept": "application/vnd.github+json"}
252266
response = requests.get(url, stream=True, headers=headers)
253267
zipName = artifact_name + ".zip"
254268
if response.status_code == 200:
@@ -262,15 +276,20 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
262276
print(f"Failed to download file. Status code: {response.status_code}")
263277
sys.exit(1)
264278

279+
265280
def remove_extension(filename: str) -> str:
    """
    Strip every trailing extension from a filename.

    Repeatedly applies os.path.splitext so multi-part extensions are fully
    removed, e.g. "db.tar.gz" -> "db".

    Args:
        filename: A bare filename (no directory components).

    Returns:
        The filename with all extensions removed.
    """
    while "." in filename:
        root, ext = os.path.splitext(filename)
        if not ext:
            # splitext made no progress (e.g. dotfiles such as ".bashrc",
            # where the extension is empty and root == filename). Without
            # this guard the loop would never terminate.
            break
        filename = root
    return filename
269284

285+
270286
def pretty_name_from_artifact_name(artifact_name: str) -> str:
    """
    Extract the human-readable project name from a DCA artifact name.

    Artifact names are "___"-separated; the pretty name is the second field.
    """
    fields = artifact_name.split("___")
    return fields[1]
272288

273-
def download_dca_databases(experiment_name: str, pat: str, projects) -> List[tuple[str, str | None]]:
289+
290+
def download_dca_databases(
291+
experiment_name: str, pat: str, projects
292+
) -> List[tuple[str, str | None]]:
274293
"""
275294
Download databases from a DCA experiment.
276295
Args:
@@ -282,58 +301,81 @@ def download_dca_databases(experiment_name: str, pat: str, projects) -> List[tup
282301
"""
283302
database_results = []
284303
print("\n=== Finding projects ===")
285-
response = github(f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json", pat)
304+
response = github(
305+
f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
306+
pat,
307+
)
286308
targets = response["targets"]
287309
for target, data in targets.items():
288-
downloads = data["downloads"]
289-
analyzed_database = downloads["analyzed_database"]
290-
artifact_name = analyzed_database["artifact_name"]
291-
pretty_name = pretty_name_from_artifact_name(artifact_name)
292-
293-
if not pretty_name in [project["name"] for project in projects]:
294-
print(f"Skipping {pretty_name} as it is not in the list of projects")
295-
continue
296-
297-
repository = analyzed_database["repository"]
298-
run_id = analyzed_database["run_id"]
299-
print(f"=== Finding artifact: {artifact_name} ===")
300-
response = github(f"https://api.github.com/repos/{repository}/actions/runs/{run_id}/artifacts", pat, { "Accept": "application/vnd.github+json" })
301-
artifacts = response["artifacts"]
302-
artifact_map = {artifact["name"]: artifact for artifact in artifacts}
303-
print(f"=== Downloading artifact: {artifact_name} ===")
304-
archive_download_url = artifact_map[artifact_name]["archive_download_url"]
305-
artifact_zip_location = download_artifact(archive_download_url, artifact_name, pat)
306-
print(f"=== Extracting artifact: {artifact_name} ===")
307-
# The database is in a zip file, which contains a tar.gz file with the DB
308-
# First we open the zip file
309-
with zipfile.ZipFile(artifact_zip_location, 'r') as zip_ref:
310-
artifact_unzipped_location = os.path.join(build_dir, artifact_name)
311-
# And then we extract it to build_dir/artifact_name
312-
zip_ref.extractall(artifact_unzipped_location)
313-
# And then we iterate over the contents of the extracted directory
314-
# and extract the tar.gz files inside it
315-
for entry in os.listdir(artifact_unzipped_location):
316-
artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
317-
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
318-
# And we just untar it to the same directory as the zip file
319-
tar_ref.extractall(artifact_unzipped_location)
320-
database_results.append((pretty_name, os.path.join(artifact_unzipped_location, remove_extension(entry))))
310+
downloads = data["downloads"]
311+
analyzed_database = downloads["analyzed_database"]
312+
artifact_name = analyzed_database["artifact_name"]
313+
pretty_name = pretty_name_from_artifact_name(artifact_name)
314+
315+
if not pretty_name in [project["name"] for project in projects]:
316+
print(f"Skipping {pretty_name} as it is not in the list of projects")
317+
continue
318+
319+
repository = analyzed_database["repository"]
320+
run_id = analyzed_database["run_id"]
321+
print(f"=== Finding artifact: {artifact_name} ===")
322+
response = github(
323+
f"https://api.github.com/repos/{repository}/actions/runs/{run_id}/artifacts",
324+
pat,
325+
{"Accept": "application/vnd.github+json"},
326+
)
327+
artifacts = response["artifacts"]
328+
artifact_map = {artifact["name"]: artifact for artifact in artifacts}
329+
print(f"=== Downloading artifact: {artifact_name} ===")
330+
archive_download_url = artifact_map[artifact_name]["archive_download_url"]
331+
artifact_zip_location = download_artifact(
332+
archive_download_url, artifact_name, pat
333+
)
334+
print(f"=== Extracting artifact: {artifact_name} ===")
335+
# The database is in a zip file, which contains a tar.gz file with the DB
336+
# First we open the zip file
337+
with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
338+
artifact_unzipped_location = os.path.join(build_dir, artifact_name)
339+
# And then we extract it to build_dir/artifact_name
340+
zip_ref.extractall(artifact_unzipped_location)
341+
# And then we iterate over the contents of the extracted directory
342+
# and extract the tar.gz files inside it
343+
for entry in os.listdir(artifact_unzipped_location):
344+
artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
345+
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
346+
# And we just untar it to the same directory as the zip file
347+
tar_ref.extractall(artifact_unzipped_location)
348+
database_results.append(
349+
(
350+
pretty_name,
351+
os.path.join(
352+
artifact_unzipped_location, remove_extension(entry)
353+
),
354+
)
355+
)
321356
print(f"\n=== Extracted {len(database_results)} databases ===")
322357

323358
def compare(a, b):
324-
a_index = next(i for i, project in enumerate(projects) if project["name"] == a[0])
325-
b_index = next(i for i, project in enumerate(projects) if project["name"] == b[0])
359+
a_index = next(
360+
i for i, project in enumerate(projects) if project["name"] == a[0]
361+
)
362+
b_index = next(
363+
i for i, project in enumerate(projects) if project["name"] == b[0]
364+
)
326365
return a_index - b_index
327366

328367
# Sort the database results based on the order in the projects file
329368
return sorted(database_results, key=cmp_to_key(compare))
330-
369+
370+
331371
def get_destination_for_project(config, name: str) -> str:
    """
    Return the output directory for a project's generated models.

    The directory is the configured "destination" root joined with the
    project name.
    """
    destination_root = config["destination"]
    return os.path.join(destination_root, name)
333373

374+
334375
def get_strategy(config) -> str:
    """
    Return the configured strategy name, normalized to lowercase.

    Lowercasing makes the later strategy dispatch case-insensitive with
    respect to the config file.
    """
    strategy = config["strategy"]
    return strategy.lower()
336377

378+
337379
def main(config, args) -> None:
338380
"""
339381
Main function to handle the bulk generation of MaD models.
@@ -371,7 +413,9 @@ def main(config, args) -> None:
371413
match get_strategy(config):
372414
case "repo":
373415
extractor_options = config.get("extractor_options", [])
374-
database_results = build_databases_from_projects(language, extractor_options, projects)
416+
database_results = build_databases_from_projects(
417+
language, extractor_options, projects
418+
)
375419
case "dca":
376420
experiment_name = args.dca
377421
if experiment_name is None:
@@ -386,9 +430,7 @@ def main(config, args) -> None:
386430
# Phase 3: Generate models for all projects
387431
print("\n=== Phase 3: Generating models ===")
388432

389-
failed_builds = [
390-
project for project, db_dir in database_results if db_dir is None
391-
]
433+
failed_builds = [project for project, db_dir in database_results if db_dir is None]
392434
if failed_builds:
393435
print(
394436
f"ERROR: {len(failed_builds)} database builds failed: {', '.join(failed_builds)}"
@@ -406,15 +448,36 @@ def main(config, args) -> None:
406448
if database_dir is not None:
407449
generate_models(args, project, database_dir)
408450

451+
409452
if __name__ == "__main__":
410453
parser = argparse.ArgumentParser()
411-
parser.add_argument("--config", type=str, help="Path to the configuration file.", required=True)
412-
parser.add_argument("--dca", type=str, help="Name of a DCA run that built all the projects", required=False)
413-
parser.add_argument("--pat", type=str, help="PAT token to grab DCA databases (the same as the one you use for DCA)", required=False)
414-
parser.add_argument("--lang", type=str, help="The language to generate models for", required=True)
415-
parser.add_argument("--with-sources", action="store_true", help="Generate sources", required=False)
416-
parser.add_argument("--with-sinks", action="store_true", help="Generate sinks", required=False)
417-
parser.add_argument("--with-summaries", action="store_true", help="Generate sinks", required=False)
454+
parser.add_argument(
455+
"--config", type=str, help="Path to the configuration file.", required=True
456+
)
457+
parser.add_argument(
458+
"--dca",
459+
type=str,
460+
help="Name of a DCA run that built all the projects",
461+
required=False,
462+
)
463+
parser.add_argument(
464+
"--pat",
465+
type=str,
466+
help="PAT token to grab DCA databases (the same as the one you use for DCA)",
467+
required=False,
468+
)
469+
parser.add_argument(
470+
"--lang", type=str, help="The language to generate models for", required=True
471+
)
472+
parser.add_argument(
473+
"--with-sources", action="store_true", help="Generate sources", required=False
474+
)
475+
parser.add_argument(
476+
"--with-sinks", action="store_true", help="Generate sinks", required=False
477+
)
478+
parser.add_argument(
479+
"--with-summaries", action="store_true", help="Generate sinks", required=False
480+
)
418481
args = parser.parse_args()
419482

420483
# Load config file

0 commit comments

Comments
 (0)