Skip to content

Commit 6a9c6cb

Browse files
committed
fixing test_db.py, the other one is still failing because of some mocking issues
1 parent b1b9851 commit 6a9c6cb

File tree

5 files changed

+102
-75
lines changed

5 files changed

+102
-75
lines changed

nettacker/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class DbConfig(ConfigBase):
8383
fill the name of the DB as sqlite,
8484
DATABASE as the name of the db user wants
8585
Set the journal_mode (default="WAL") and
86-
synchronous_mode (deafault="NORMAL"). Rest
86+
synchronous_mode (default="NORMAL"). Rest
8787
of the fields can be left emptyAdd commentMore actions
8888
This is the default database:
8989
str(CWD / ".data/nettacker.db")

nettacker/database/db.py

Lines changed: 68 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import json
22
import time
33

4-
import apsw
4+
try:
5+
import apsw
6+
except ImportError:
7+
apsw = None
8+
59
from sqlalchemy import create_engine
610
from sqlalchemy.orm import sessionmaker
711

@@ -12,7 +16,7 @@
1216
from nettacker.database.models import HostsLog, Report, TempEvents
1317

1418
config = Config()
15-
logging = logger.get_logger()
19+
logger = logger.get_logger()
1620

1721

1822
def db_inputs(connection_type):
@@ -44,17 +48,24 @@ def create_connection():
4448
connection failed.
4549
"""
4650
if Config.db.engine.startswith("sqlite"):
51+
if apsw is None:
52+
raise ImportError("APSW is required for SQLite backend.")
4753
# In case of sqlite, the name parameter is the database path
48-
DB_PATH = config.db.as_dict()["name"]
49-
connection = apsw.Connection(DB_PATH)
50-
connection.setbusytimeout(int(config.settings.timeout) * 100)
51-
cursor = connection.cursor()
5254

53-
# Performance enhancing configurations. Put WAL cause that helps with concurrency
54-
cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}")
55-
cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
55+
try:
56+
DB_PATH = config.db.as_dict()["name"]
57+
connection = apsw.Connection(DB_PATH)
58+
connection.setbusytimeout(int(config.settings.timeout) * 100)
59+
cursor = connection.cursor()
60+
61+
# Performance enhancing configurations. Put WAL cause that helps with concurrency
62+
cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}")
63+
cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
5664

57-
return connection, cursor
65+
return connection, cursor
66+
except Exception as e:
67+
logger.error(f"Failed to create APSW connection: {e}")
68+
raise
5869

5970
else:
6071
# Both MySQL and PostgreSQL don't need a
@@ -94,7 +105,7 @@ def send_submit_query(session):
94105
finally:
95106
connection.close()
96107
connection.close()
97-
logging.warn(messages("database_connect_fail"))
108+
logger.warn(messages("database_connect_fail"))
98109
return False
99110
else:
100111
try:
@@ -104,10 +115,10 @@ def send_submit_query(session):
104115
return True
105116
except Exception:
106117
time.sleep(0.1)
107-
logging.warn(messages("database_connect_fail"))
118+
logger.warn(messages("database_connect_fail"))
108119
return False
109120
except Exception:
110-
logging.warn(messages("database_connect_fail"))
121+
logger.warn(messages("database_connect_fail"))
111122
return False
112123
return False
113124

@@ -123,7 +134,7 @@ def submit_report_to_db(event):
123134
Returns:
124135
return True if submitted otherwise False
125136
"""
126-
logging.verbose_info(messages("inserting_report_db"))
137+
logger.verbose_info(messages("inserting_report_db"))
127138
session = create_connection()
128139

129140
if isinstance(session, tuple):
@@ -146,7 +157,7 @@ def submit_report_to_db(event):
146157
return send_submit_query(session)
147158
except Exception:
148159
cursor.execute("ROLLBACK")
149-
logging.warn("Could not insert report...")
160+
logger.warn("Could not insert report...")
150161
return False
151162
finally:
152163
cursor.close()
@@ -197,7 +208,7 @@ def remove_old_logs(options):
197208
return send_submit_query(session)
198209
except Exception:
199210
cursor.execute("ROLLBACK")
200-
logging.warn("Could not remove old logs...")
211+
logger.warn("Could not remove old logs...")
201212
return False
202213
finally:
203214
cursor.close()
@@ -253,7 +264,7 @@ def submit_logs_to_db(log):
253264

254265
except apsw.BusyError as e:
255266
if "database is locked" in str(e).lower():
256-
logging.warn(
267+
logger.warn(
257268
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
258269
)
259270
if connection.in_transaction:
@@ -272,7 +283,7 @@ def submit_logs_to_db(log):
272283
pass
273284
return False
274285
# All retries exhausted but we want to continue operation
275-
logging.warn("All retries exhausted. Skipping this log.")
286+
logger.warn("All retries exhausted. Skipping this log.")
276287
return True
277288
finally:
278289
cursor.close()
@@ -290,7 +301,7 @@ def submit_logs_to_db(log):
290301
)
291302
return send_submit_query(session)
292303
else:
293-
logging.warn(messages("invalid_json_type_to_db").format(log))
304+
logger.warn(messages("invalid_json_type_to_db").format(log))
294305
return False
295306

296307

@@ -335,7 +346,7 @@ def submit_temp_logs_to_db(log):
335346
return send_submit_query(session)
336347
except apsw.BusyError as e:
337348
if "database is locked" in str(e).lower():
338-
logging.warn(
349+
logger.warn(
339350
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
340351
)
341352
try:
@@ -360,7 +371,7 @@ def submit_temp_logs_to_db(log):
360371
pass
361372
return False
362373
# All retries exhausted but we want to continue operation
363-
logging.warn("All retries exhausted. Skipping this log.")
374+
logger.warn("All retries exhausted. Skipping this log.")
364375
return True
365376
finally:
366377
cursor.close()
@@ -379,7 +390,7 @@ def submit_temp_logs_to_db(log):
379390
)
380391
return send_submit_query(session)
381392
else:
382-
logging.warn(messages("invalid_json_type_to_db").format(log))
393+
logger.warn(messages("invalid_json_type_to_db").format(log))
383394
return False
384395

385396

@@ -400,28 +411,23 @@ def find_temp_events(target, module_name, scan_id, event_name):
400411
if isinstance(session, tuple):
401412
connection, cursor = session
402413
try:
403-
for _ in range(100):
404-
try:
405-
cursor.execute(
406-
"""
407-
SELECT event
408-
FROM temp_events
409-
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
410-
LIMIT 1
411-
""",
412-
(target, module_name, scan_id, event_name),
413-
)
414+
cursor.execute(
415+
"""
416+
SELECT event
417+
FROM temp_events
418+
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
419+
LIMIT 1
420+
""",
421+
(target, module_name, scan_id, event_name),
422+
)
414423

415-
row = cursor.fetchone()
416-
cursor.close()
417-
if row:
418-
return json.loads(row[0])
419-
return []
420-
except Exception:
421-
logging.warn("Database query failed...")
422-
return []
424+
row = cursor.fetchone()
425+
cursor.close()
426+
if row:
427+
return json.loads(row[0])
428+
return []
423429
except Exception:
424-
logging.warn(messages("database_connect_fail"))
430+
logger.warn(messages("database_connect_fail"))
425431
return []
426432
return []
427433
else:
@@ -441,7 +447,7 @@ def find_temp_events(target, module_name, scan_id, event_name):
441447
except Exception:
442448
time.sleep(0.1)
443449
except Exception:
444-
logging.warn(messages("database_connect_fail"))
450+
logger.warn(messages("database_connect_fail"))
445451
return []
446452
return []
447453

@@ -477,7 +483,7 @@ def find_events(target, module_name, scan_id):
477483
return [json.dumps((json.loads(row[0]))) for row in rows]
478484
return []
479485
except Exception:
480-
logging.warn("Database query failed...")
486+
logger.warn("Database query failed...")
481487
return []
482488
else:
483489
return [
@@ -536,7 +542,7 @@ def select_reports(page):
536542
return selected
537543

538544
except Exception:
539-
logging.warn("Could not retrieve report...")
545+
logger.warn("Could not retrieve report...")
540546
return structure(status="error", msg="database error!")
541547
else:
542548
try:
@@ -582,13 +588,23 @@ def get_scan_result(id):
582588
cursor.close()
583589
if row:
584590
filename = row[0]
585-
return filename, open(str(filename), "rb").read()
591+
try:
592+
return filename, open(str(filename), "rb").read()
593+
except IOError as e:
594+
logger.error(f"Failed to read report file: {e}")
595+
return None
586596
else:
587597
return structure(status="error", msg="database error!")
588598
else:
589-
filename = session.query(Report).filter_by(id=id).first().report_path_filename
599+
report = session.query(Report).filter_by(id=id).first()
600+
if not report:
601+
return None
590602

591-
return filename, open(str(filename), "rb").read()
603+
try:
604+
return report.report_path_filename, open(str(report.report_path_filename), "rb").read()
605+
except IOError as e:
606+
logger.error(f"Failed to read report file: {e}")
607+
return None
592608

593609

594610
def last_host_logs(page):
@@ -656,7 +672,6 @@ def last_host_logs(page):
656672
)
657673
events = [row[0] for row in cursor.fetchall()]
658674

659-
cursor.close()
660675
hosts.append(
661676
{
662677
"target": target,
@@ -667,11 +682,11 @@ def last_host_logs(page):
667682
},
668683
}
669684
)
670-
685+
cursor.close()
671686
return hosts
672687

673688
except Exception:
674-
logging.warn("Database query failed...")
689+
logger.warn("Database query failed...")
675690
return structure(status="error", msg="Database error!")
676691

677692
else:
@@ -834,7 +849,7 @@ def logs_to_report_json(target):
834849
"event": json.loads(log[3]),
835850
"json_event": json.loads(log[4]),
836851
}
837-
return_logs.append(data)
852+
return_logs.append(data)
838853
return return_logs
839854

840855
else:
@@ -1001,7 +1016,6 @@ def search_logs(page, query):
10011016
),
10021017
)
10031018
targets = cursor.fetchall()
1004-
cursor.close()
10051019
for target_row in targets:
10061020
target = target_row[0]
10071021
# Fetch data for each target grouped by key fields
@@ -1044,6 +1058,7 @@ def search_logs(page, query):
10441058
tmp["info"]["json_event"].append(parsed_json_event)
10451059

10461060
selected.append(tmp)
1061+
cursor.close()
10471062

10481063
except Exception:
10491064
return structure(status="error", msg="database error!")

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ profile = "black"
9494
[tool.pytest.ini_options]
9595
addopts = "--cov=nettacker --cov-config=pyproject.toml --cov-report term --cov-report xml --dist loadscope --no-cov-on-fail --numprocesses auto"
9696
testpaths = ["tests"]
97+
markers = [
98+
"asyncio: mark test as async"
99+
]
97100

98101
[tool.ruff]
99102
line-length = 99

tests/core/test_exclude_ports.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,12 @@ def test_load_with_service_discovery(
5959
}
6060
mock_loader.return_value = mock_loader_inst
6161

62-
mock_find_events.return_value = [
63-
MagicMock(json_event='{"port": 80, "response": {"conditions_results": {"http": {}}}}')
64-
]
62+
mock_event1 = MagicMock()
63+
mock_event1.json_event = json.dumps(
64+
{"port": 80, "response": {"conditions_results": {"http": {}}}}
65+
)
66+
67+
mock_find_events.return_value = [mock_event1]
6568

6669
module = Module("test_module", options, **module_args)
6770
module.load()

0 commit comments

Comments
 (0)