36
36
logging .basicConfig (level = logging .DEBUG )
37
37
38
38
# database defaults
39
+ DISK_LOCATION_BACKUP = os .path .join (
40
+ os .path .expanduser ("~" ), ".cache" , "cve-bin-tool-backup"
41
+ )
39
42
DISK_LOCATION_DEFAULT = os .path .join (os .path .expanduser ("~" ), ".cache" , "cve-bin-tool" )
40
43
DBNAME = "cve.db"
41
44
OLD_CACHE_DIR = os .path .join (os .path .expanduser ("~" ), ".cache" , "cvedb" )
@@ -47,6 +50,7 @@ class CVEDB:
47
50
"""
48
51
49
52
CACHEDIR = DISK_LOCATION_DEFAULT
53
+ BACKUPCACHEDIR = DISK_LOCATION_BACKUP
50
54
FEED = "https://nvd.nist.gov/vuln/data-feeds"
51
55
LOGGER = LOGGER .getChild ("CVEDB" )
52
56
NVDCVE_FILENAME_TEMPLATE = "nvdcve-1.1-{}.json.gz"
@@ -59,12 +63,16 @@ def __init__(
59
63
self ,
60
64
feed = None ,
61
65
cachedir = None ,
66
+ backup_cachedir = None ,
62
67
version_check = True ,
63
68
session = None ,
64
69
error_mode = ErrorMode .TruncTrace ,
65
70
):
66
71
self .feed = feed if feed is not None else self .FEED
67
72
self .cachedir = cachedir if cachedir is not None else self .CACHEDIR
73
+ self .backup_cachedir = (
74
+ backup_cachedir if backup_cachedir is not None else self .BACKUPCACHEDIR
75
+ )
68
76
self .error_mode = error_mode
69
77
# Will be true if refresh was successful
70
78
self .was_updated = False
@@ -78,6 +86,9 @@ def __init__(
78
86
self .session = session
79
87
self .cve_count = - 1
80
88
89
+ if not os .path .exists (self .dbpath ):
90
+ self .rollback_cache_backup ()
91
+
81
92
def get_cve_count (self ):
82
93
if self .cve_count == - 1 :
83
94
# Force update
@@ -618,8 +629,9 @@ def curl_versions(self):
618
629
]
619
630
620
631
def clear_cached_data (self ):
632
+ self .create_cache_backup ()
621
633
if os .path .exists (self .cachedir ):
622
- self .LOGGER .warning (f"Deleting cachedir { self .cachedir } " )
634
+ self .LOGGER .warning (f"Updating cachedir { self .cachedir } " )
623
635
shutil .rmtree (self .cachedir )
624
636
# Remove files associated with pre-1.0 development tree
625
637
if os .path .exists (OLD_CACHE_DIR ):
@@ -636,3 +648,29 @@ def db_close(self):
636
648
if self .connection :
637
649
self .connection .close ()
638
650
self .connection = None
651
+
652
+ def create_cache_backup (self ):
653
+ """Creates a backup of the cachedir in case anything fails"""
654
+ if os .path .exists (self .cachedir ):
655
+ self .LOGGER .debug (
656
+ f"Creating backup of cachedir { self .cachedir } at { self .backup_cachedir } "
657
+ )
658
+ self .remove_cache_backup ()
659
+ shutil .copytree (self .cachedir , self .backup_cachedir )
660
+
661
+ def remove_cache_backup (self ):
662
+ """Removes the backup if database was successfully loaded"""
663
+ if os .path .exists (self .backup_cachedir ):
664
+ self .LOGGER .debug (f"Removing backup cache from { self .backup_cachedir } " )
665
+ shutil .rmtree (self .backup_cachedir )
666
+
667
+ def rollback_cache_backup (self ):
668
+ """Rollback the cachedir backup in case anything fails"""
669
+ if os .path .exists (os .path .join (self .backup_cachedir , DBNAME )):
670
+ self .LOGGER .info (f"Rolling back the cache to its previous state" )
671
+ if os .path .exists (self .cachedir ):
672
+ shutil .rmtree (self .cachedir )
673
+ shutil .move (self .backup_cachedir , self .cachedir )
674
+
675
+ def __del__ (self ):
676
+ self .rollback_cache_backup ()
0 commit comments