Skip to content

Commit 261c129

Browse files
committed
MaD generator: add single file mode
1 parent 2818e6e commit 261c129

File tree

3 files changed

+47
-37
lines changed

3 files changed

+47
-37
lines changed

misc/scripts/models-as-data/bulk_generate_mad.py

Lines changed: 36 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
Note: This file must be formatted using the Black Python formatter.
66
"""
77

8-
import os.path
8+
import pathlib
99
import subprocess
1010
import sys
1111
from typing import Required, TypedDict, List, Callable, Optional
@@ -41,7 +41,7 @@ def missing_module(module_name: str) -> None:
4141
.decode("utf-8")
4242
.strip()
4343
)
44-
build_dir = os.path.join(gitroot, "mad-generation-build")
44+
build_dir = pathlib.Path(gitroot, "mad-generation-build")
4545

4646

4747
# A project to generate models for
@@ -86,10 +86,10 @@ def clone_project(project: Project) -> str:
8686
git_tag = project.get("git-tag")
8787

8888
# Determine target directory
89-
target_dir = os.path.join(build_dir, name)
89+
target_dir = build_dir / name
9090

9191
# Clone only if directory doesn't already exist
92-
if not os.path.exists(target_dir):
92+
if not target_dir.exists():
9393
if git_tag:
9494
print(f"Cloning {name} from {repo_url} at tag {git_tag}")
9595
else:
@@ -191,10 +191,10 @@ def build_database(
191191
name = project["name"]
192192

193193
# Create database directory path
194-
database_dir = os.path.join(build_dir, f"{name}-db")
194+
database_dir = build_dir / f"{name}-db"
195195

196196
# Only build the database if it doesn't already exist
197-
if not os.path.exists(database_dir):
197+
if not database_dir.exists():
198198
print(f"Building CodeQL database for {name}...")
199199
extractor_options = [option for x in extractor_options for option in ("-O", x)]
200200
try:
@@ -241,7 +241,11 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
241241
generator.with_summaries = should_generate_summaries(project)
242242
generator.threads = args.codeql_threads
243243
generator.ram = args.codeql_ram
244-
generator.setenvironment(database=database_dir, folder=name)
244+
if config.get("single-file", False):
245+
generator.single_file = name
246+
else:
247+
generator.folder = name
248+
generator.setenvironment(database=database_dir)
245249
generator.run()
246250

247251

@@ -312,20 +316,14 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
312316
if response.status_code != 200:
313317
print(f"Failed to download file. Status code: {response.status_code}")
314318
sys.exit(1)
315-
target_zip = os.path.join(build_dir, zipName)
319+
target_zip = build_dir / zipName
316320
with open(target_zip, "wb") as file:
317321
for chunk in response.iter_content(chunk_size=8192):
318322
file.write(chunk)
319323
print(f"Download complete: {target_zip}")
320324
return target_zip
321325

322326

323-
def remove_extension(filename: str) -> str:
324-
while "." in filename:
325-
filename, _ = os.path.splitext(filename)
326-
return filename
327-
328-
329327
def pretty_name_from_artifact_name(artifact_name: str) -> str:
330328
return artifact_name.split("___")[1]
331329

@@ -399,19 +397,17 @@ def download_and_decompress(analyzed_database: dict) -> str:
399397
# The database is in a zip file, which contains a tar.gz file with the DB
400398
# First we open the zip file
401399
with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
402-
artifact_unzipped_location = os.path.join(build_dir, artifact_name)
400+
artifact_unzipped_location = build_dir / artifact_name
403401
# clean up any remnants of previous runs
404402
shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
405403
# And then we extract it to build_dir/artifact_name
406404
zip_ref.extractall(artifact_unzipped_location)
407405
# And then we extract the language tar.gz file inside it
408-
artifact_tar_location = os.path.join(
409-
artifact_unzipped_location, f"{language}.tar.gz"
410-
)
406+
artifact_tar_location = artifact_unzipped_location / f"{language}.tar.gz"
411407
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
412408
# And we just untar it to the same directory as the zip file
413409
tar_ref.extractall(artifact_unzipped_location)
414-
ret = os.path.join(artifact_unzipped_location, language)
410+
ret = artifact_unzipped_location / language
415411
print(f"Decompression complete: {ret}")
416412
return ret
417413

@@ -431,8 +427,16 @@ def download_and_decompress(analyzed_database: dict) -> str:
431427
return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
432428

433429

434-
def get_mad_destination_for_project(config, name: str) -> str:
435-
return os.path.join(config["destination"], name)
430+
def clean_up_mad_destination_for_project(config, name: str) -> None:
    """Delete previously generated MaD data for one project.

    In single-file mode (``config["single-file"]`` is truthy) the target is
    the file ``<destination>/<name>.model.yml``; otherwise it is the directory
    ``<destination>/<name>``. Nothing happens if the target does not exist.

    Args:
        config: Configuration mapping. ``destination`` is required;
            ``single-file`` is optional and treated as False when absent.
        name: Project name, used as the directory name or the base file name.
    """
    target = pathlib.Path(config["destination"], name)
    if config.get("single-file", False):
        # Append the extension instead of calling with_suffix(): with_suffix()
        # replaces an existing suffix, so a dotted project name such as
        # "foo.bar" would resolve to "foo.model.yml" rather than
        # "foo.bar.model.yml".
        target = target.with_name(f"{target.name}.model.yml")
        if target.exists():
            print(f"Deleting existing MaD file at {target}")
            target.unlink()
    elif target.exists():
        print(f"Deleting existing MaD directory at {target}")
        # Best-effort removal: errors while deleting are deliberately ignored.
        shutil.rmtree(target, ignore_errors=True)
436440

437441

438442
def get_strategy(config) -> str:
@@ -454,8 +458,7 @@ def main(config, args) -> None:
454458
language = config["language"]
455459

456460
# Create build directory if it doesn't exist
457-
if not os.path.exists(build_dir):
458-
os.makedirs(build_dir)
461+
build_dir.mkdir(parents=True, exist_ok=True)
459462

460463
database_results = []
461464
match get_strategy(config):
@@ -475,7 +478,7 @@ def main(config, args) -> None:
475478
if args.pat is None:
476479
print("ERROR: --pat argument is required for DCA strategy")
477480
sys.exit(1)
478-
if not os.path.exists(args.pat):
481+
if not args.pat.exists():
479482
print(f"ERROR: Personal Access Token file '{pat}' does not exist.")
480483
sys.exit(1)
481484
with open(args.pat, "r") as f:
@@ -499,12 +502,9 @@ def main(config, args) -> None:
499502
)
500503
sys.exit(1)
501504

502-
# Delete the MaD directory for each project
503-
for project, database_dir in database_results:
504-
mad_dir = get_mad_destination_for_project(config, project["name"])
505-
if os.path.exists(mad_dir):
506-
print(f"Deleting existing MaD directory at {mad_dir}")
507-
subprocess.check_call(["rm", "-rf", mad_dir])
505+
# clean up existing MaD data for the projects
506+
for project, _ in database_results:
507+
clean_up_mad_destination_for_project(config, project["name"])
508508

509509
for project, database_dir in database_results:
510510
if database_dir is not None:
@@ -514,7 +514,10 @@ def main(config, args) -> None:
514514
if __name__ == "__main__":
515515
parser = argparse.ArgumentParser()
516516
parser.add_argument(
517-
"--config", type=str, help="Path to the configuration file.", required=True
517+
"--config",
518+
type=pathlib.Path,
519+
help="Path to the configuration file.",
520+
required=True,
518521
)
519522
parser.add_argument(
520523
"--dca",
@@ -525,7 +528,7 @@ def main(config, args) -> None:
525528
)
526529
parser.add_argument(
527530
"--pat",
528-
type=str,
531+
type=pathlib.Path,
529532
help="Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)",
530533
)
531534
parser.add_argument(
@@ -544,7 +547,7 @@ def main(config, args) -> None:
544547

545548
# Load config file
546549
config = {}
547-
if not os.path.exists(args.config):
550+
if not args.config.exists():
548551
print(f"ERROR: Config file '{args.config}' does not exist.")
549552
sys.exit(1)
550553
try:

misc/scripts/models-as-data/generate_mad.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,21 @@ class Generator:
5353
ram = None
5454
threads = 0
5555
folder = ""
56+
single_file = None
5657

5758
def __init__(self, language=None):
5859
self.language = language
5960

6061
def setenvironment(self, database=None, folder=None):
61-
self.codeQlRoot = (
62+
self.codeql_root = (
6263
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
6364
.decode("utf-8")
6465
.strip()
6566
)
6667
self.database = database or self.database
6768
self.folder = folder or self.folder
6869
self.generated_frameworks = os.path.join(
69-
self.codeQlRoot, f"{self.language}/ql/lib/ext/generated/{self.folder}"
70+
self.codeql_root, f"{self.language}/ql/lib/ext/generated/{self.folder}"
7071
)
7172
self.workDir = tempfile.mkdtemp()
7273
if self.ram is None:
@@ -134,6 +135,10 @@ def make():
134135
type=int,
135136
help="Amount of RAM to use for CodeQL queries in MB. Default is to use 2048 MB per thread.",
136137
)
138+
p.add_argument(
139+
"--single-file",
140+
help="Generate a single file with all models instead of separate files for each namespace, using provided argument as the base filename.",
141+
)
137142
generator = p.parse_args(namespace=Generator())
138143

139144
if (
@@ -154,7 +159,7 @@ def make():
154159
def runQuery(self, query):
155160
print("########## Querying " + query + "...")
156161
queryFile = os.path.join(
157-
self.codeQlRoot, f"{self.language}/ql/src/utils/{self.dirname}", query
162+
self.codeql_root, f"{self.language}/ql/src/utils/{self.dirname}", query
158163
)
159164
resultBqrs = os.path.join(self.workDir, "out.bqrs")
160165

@@ -187,6 +192,8 @@ def asAddsTo(self, rows, predicate):
187192
def getAddsTo(self, query, predicate):
188193
data = self.runQuery(query)
189194
rows = parseData(data)
195+
if self.single_file and rows:
196+
rows = {self.single_file: "".join(rows.values())}
190197
return self.asAddsTo(rows, predicate)
191198

192199
def makeContent(self):

misc/scripts/models-as-data/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def remove_dir(dirName):
2222

2323

2424
def run_cmd(cmd, msg="Failed to run command"):
25-
print("Running " + " ".join(cmd))
25+
print("Running " + " ".join(map(str, cmd)))
2626
if subprocess.check_call(cmd):
2727
print(msg)
2828
exit(1)

0 commit comments

Comments
 (0)