5
5
Note: This file must be formatted using the Black Python formatter.
6
6
"""
7
7
8
- import os . path
8
+ import pathlib
9
9
import subprocess
10
10
import sys
11
11
from typing import Required , TypedDict , List , Callable , Optional
@@ -41,7 +41,7 @@ def missing_module(module_name: str) -> None:
41
41
.decode ("utf-8" )
42
42
.strip ()
43
43
)
44
- build_dir = os . path . join (gitroot , "mad-generation-build" )
44
+ build_dir = pathlib . Path (gitroot , "mad-generation-build" )
45
45
46
46
47
47
# A project to generate models for
@@ -86,10 +86,10 @@ def clone_project(project: Project) -> str:
86
86
git_tag = project .get ("git-tag" )
87
87
88
88
# Determine target directory
89
- target_dir = os . path . join ( build_dir , name )
89
+ target_dir = build_dir / name
90
90
91
91
# Clone only if directory doesn't already exist
92
- if not os . path . exists (target_dir ):
92
+ if not target_dir . exists ():
93
93
if git_tag :
94
94
print (f"Cloning { name } from { repo_url } at tag { git_tag } " )
95
95
else :
@@ -191,10 +191,10 @@ def build_database(
191
191
name = project ["name" ]
192
192
193
193
# Create database directory path
194
- database_dir = os . path . join ( build_dir , f"{ name } -db" )
194
+ database_dir = build_dir / f"{ name } -db"
195
195
196
196
# Only build the database if it doesn't already exist
197
- if not os . path . exists (database_dir ):
197
+ if not database_dir . exists ():
198
198
print (f"Building CodeQL database for { name } ..." )
199
199
extractor_options = [option for x in extractor_options for option in ("-O" , x )]
200
200
try :
@@ -241,7 +241,11 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
241
241
generator .with_summaries = should_generate_summaries (project )
242
242
generator .threads = args .codeql_threads
243
243
generator .ram = args .codeql_ram
244
- generator .setenvironment (database = database_dir , folder = name )
244
+ if config .get ("single-file" , False ):
245
+ generator .single_file = name
246
+ else :
247
+ generator .folder = name
248
+ generator .setenvironment (database = database_dir )
245
249
generator .run ()
246
250
247
251
@@ -312,20 +316,14 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
312
316
if response .status_code != 200 :
313
317
print (f"Failed to download file. Status code: { response .status_code } " )
314
318
sys .exit (1 )
315
- target_zip = os . path . join ( build_dir , zipName )
319
+ target_zip = build_dir / zipName
316
320
with open (target_zip , "wb" ) as file :
317
321
for chunk in response .iter_content (chunk_size = 8192 ):
318
322
file .write (chunk )
319
323
print (f"Download complete: { target_zip } " )
320
324
return target_zip
321
325
322
326
323
- def remove_extension (filename : str ) -> str :
324
- while "." in filename :
325
- filename , _ = os .path .splitext (filename )
326
- return filename
327
-
328
-
329
327
def pretty_name_from_artifact_name (artifact_name : str ) -> str :
330
328
return artifact_name .split ("___" )[1 ]
331
329
@@ -399,19 +397,17 @@ def download_and_decompress(analyzed_database: dict) -> str:
399
397
# The database is in a zip file, which contains a tar.gz file with the DB
400
398
# First we open the zip file
401
399
with zipfile .ZipFile (artifact_zip_location , "r" ) as zip_ref :
402
- artifact_unzipped_location = os . path . join ( build_dir , artifact_name )
400
+ artifact_unzipped_location = build_dir / artifact_name
403
401
# clean up any remnants of previous runs
404
402
shutil .rmtree (artifact_unzipped_location , ignore_errors = True )
405
403
# And then we extract it to build_dir/artifact_name
406
404
zip_ref .extractall (artifact_unzipped_location )
407
405
# And then we extract the language tar.gz file inside it
408
- artifact_tar_location = os .path .join (
409
- artifact_unzipped_location , f"{ language } .tar.gz"
410
- )
406
+ artifact_tar_location = artifact_unzipped_location / f"{ language } .tar.gz"
411
407
with tarfile .open (artifact_tar_location , "r:gz" ) as tar_ref :
412
408
# And we just untar it to the same directory as the zip file
413
409
tar_ref .extractall (artifact_unzipped_location )
414
- ret = os . path . join ( artifact_unzipped_location , language )
410
+ ret = artifact_unzipped_location / language
415
411
print (f"Decompression complete: { ret } " )
416
412
return ret
417
413
@@ -431,8 +427,16 @@ def download_and_decompress(analyzed_database: dict) -> str:
431
427
return [(project_map [n ], r ) for n , r in zip (analyzed_databases , results )]
432
428
433
429
434
- def get_mad_destination_for_project (config , name : str ) -> str :
435
- return os .path .join (config ["destination" ], name )
430
+ def clean_up_mad_destination_for_project (config , name : str ):
431
+ target = pathlib .Path (config ["destination" ], name )
432
+ if config .get ("single-file" , False ):
433
+ target = target .with_suffix (".model.yml" )
434
+ if target .exists ():
435
+ print (f"Deleting existing MaD file at { target } " )
436
+ target .unlink ()
437
+ elif target .exists ():
438
+ print (f"Deleting existing MaD directory at { target } " )
439
+ shutil .rmtree (target , ignore_errors = True )
436
440
437
441
438
442
def get_strategy (config ) -> str :
@@ -454,8 +458,7 @@ def main(config, args) -> None:
454
458
language = config ["language" ]
455
459
456
460
# Create build directory if it doesn't exist
457
- if not os .path .exists (build_dir ):
458
- os .makedirs (build_dir )
461
+ build_dir .mkdir (parents = True , exist_ok = True )
459
462
460
463
database_results = []
461
464
match get_strategy (config ):
@@ -475,7 +478,7 @@ def main(config, args) -> None:
475
478
if args .pat is None :
476
479
print ("ERROR: --pat argument is required for DCA strategy" )
477
480
sys .exit (1 )
478
- if not os . path .exists (args . pat ):
481
+ if not args . pat .exists ():
479
482
print (f"ERROR: Personal Access Token file '{ pat } ' does not exist." )
480
483
sys .exit (1 )
481
484
with open (args .pat , "r" ) as f :
@@ -499,12 +502,9 @@ def main(config, args) -> None:
499
502
)
500
503
sys .exit (1 )
501
504
502
- # Delete the MaD directory for each project
503
- for project , database_dir in database_results :
504
- mad_dir = get_mad_destination_for_project (config , project ["name" ])
505
- if os .path .exists (mad_dir ):
506
- print (f"Deleting existing MaD directory at { mad_dir } " )
507
- subprocess .check_call (["rm" , "-rf" , mad_dir ])
505
+ # clean up existing MaD data for the projects
506
+ for project , _ in database_results :
507
+ clean_up_mad_destination_for_project (config , project ["name" ])
508
508
509
509
for project , database_dir in database_results :
510
510
if database_dir is not None :
@@ -514,7 +514,10 @@ def main(config, args) -> None:
514
514
if __name__ == "__main__" :
515
515
parser = argparse .ArgumentParser ()
516
516
parser .add_argument (
517
- "--config" , type = str , help = "Path to the configuration file." , required = True
517
+ "--config" ,
518
+ type = pathlib .Path ,
519
+ help = "Path to the configuration file." ,
520
+ required = True ,
518
521
)
519
522
parser .add_argument (
520
523
"--dca" ,
@@ -525,7 +528,7 @@ def main(config, args) -> None:
525
528
)
526
529
parser .add_argument (
527
530
"--pat" ,
528
- type = str ,
531
+ type = pathlib . Path ,
529
532
help = "Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)" ,
530
533
)
531
534
parser .add_argument (
@@ -544,7 +547,7 @@ def main(config, args) -> None:
544
547
545
548
# Load config file
546
549
config = {}
547
- if not os . path .exists (args . config ):
550
+ if not args . config .exists ():
548
551
print (f"ERROR: Config file '{ args .config } ' does not exist." )
549
552
sys .exit (1 )
550
553
try :
0 commit comments