Skip to content

Commit f0f4a9d

Browse files
authored
feat (cvedb): rollback cachedir if cvedb refresh fails (#1225)
* feat (cvedb): rollback cachedir if cvedb refresh fails * pass backup_cachedir argument in cvedb * update test_cli.py log assertion
1 parent c9620ad commit f0f4a9d

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

cve_bin_tool/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ def main(argv=None):
308308
with ErrorHandler(mode=error_mode, logger=LOGGER):
309309
raise CVEDataMissing("No data in CVE Database")
310310

311+
cvedb_orig.remove_cache_backup()
312+
311313
# Input validation
312314
if not args["directory"] and not args["input_file"] and not args["package_list"]:
313315
parser.print_usage()

cve_bin_tool/cvedb.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
logging.basicConfig(level=logging.DEBUG)
3737

3838
# database defaults
39+
DISK_LOCATION_BACKUP = os.path.join(
40+
os.path.expanduser("~"), ".cache", "cve-bin-tool-backup"
41+
)
3942
DISK_LOCATION_DEFAULT = os.path.join(os.path.expanduser("~"), ".cache", "cve-bin-tool")
4043
DBNAME = "cve.db"
4144
OLD_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "cvedb")
@@ -47,6 +50,7 @@ class CVEDB:
4750
"""
4851

4952
CACHEDIR = DISK_LOCATION_DEFAULT
53+
BACKUPCACHEDIR = DISK_LOCATION_BACKUP
5054
FEED = "https://nvd.nist.gov/vuln/data-feeds"
5155
LOGGER = LOGGER.getChild("CVEDB")
5256
NVDCVE_FILENAME_TEMPLATE = "nvdcve-1.1-{}.json.gz"
@@ -59,12 +63,16 @@ def __init__(
5963
self,
6064
feed=None,
6165
cachedir=None,
66+
backup_cachedir=None,
6267
version_check=True,
6368
session=None,
6469
error_mode=ErrorMode.TruncTrace,
6570
):
6671
self.feed = feed if feed is not None else self.FEED
6772
self.cachedir = cachedir if cachedir is not None else self.CACHEDIR
73+
self.backup_cachedir = (
74+
backup_cachedir if backup_cachedir is not None else self.BACKUPCACHEDIR
75+
)
6876
self.error_mode = error_mode
6977
# Will be true if refresh was successful
7078
self.was_updated = False
@@ -78,6 +86,9 @@ def __init__(
7886
self.session = session
7987
self.cve_count = -1
8088

89+
if not os.path.exists(self.dbpath):
90+
self.rollback_cache_backup()
91+
8192
def get_cve_count(self):
8293
if self.cve_count == -1:
8394
# Force update
@@ -618,8 +629,9 @@ def curl_versions(self):
618629
]
619630

620631
def clear_cached_data(self):
632+
self.create_cache_backup()
621633
if os.path.exists(self.cachedir):
622-
self.LOGGER.warning(f"Deleting cachedir {self.cachedir}")
634+
self.LOGGER.warning(f"Updating cachedir {self.cachedir}")
623635
shutil.rmtree(self.cachedir)
624636
# Remove files associated with pre-1.0 development tree
625637
if os.path.exists(OLD_CACHE_DIR):
@@ -636,3 +648,29 @@ def db_close(self):
636648
if self.connection:
637649
self.connection.close()
638650
self.connection = None
651+
652+
def create_cache_backup(self):
653+
"""Creates a backup of the cachedir in case anything fails"""
654+
if os.path.exists(self.cachedir):
655+
self.LOGGER.debug(
656+
f"Creating backup of cachedir {self.cachedir} at {self.backup_cachedir}"
657+
)
658+
self.remove_cache_backup()
659+
shutil.copytree(self.cachedir, self.backup_cachedir)
660+
661+
def remove_cache_backup(self):
662+
"""Removes the backup if database was successfully loaded"""
663+
if os.path.exists(self.backup_cachedir):
664+
self.LOGGER.debug(f"Removing backup cache from {self.backup_cachedir}")
665+
shutil.rmtree(self.backup_cachedir)
666+
667+
def rollback_cache_backup(self):
668+
"""Rollback the cachedir backup in case anything fails"""
669+
if os.path.exists(os.path.join(self.backup_cachedir, DBNAME)):
670+
self.LOGGER.info(f"Rolling back the cache to its previous state")
671+
if os.path.exists(self.cachedir):
672+
shutil.rmtree(self.cachedir)
673+
shutil.move(self.backup_cachedir, self.cachedir)
674+
675+
def __del__(self):
676+
self.rollback_cache_backup()

test/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def test_update(self, caplog):
206206
assert (
207207
"cve_bin_tool.CVEDB",
208208
logging.WARNING,
209-
f"Deleting cachedir {db_path}",
209+
f"Updating cachedir {db_path}",
210210
) in caplog.record_tuples and (
211211
"cve_bin_tool.CVEDB",
212212
logging.INFO,

0 commit comments

Comments
 (0)