Skip to content

Commit af7abb6

Browse files
authored
Unittets for database files (#1077)
* unittests for database files * ruff fixes
1 parent 4fd743a commit af7abb6

File tree

4 files changed

+369
-0
lines changed

4 files changed

+369
-0
lines changed

tests/database/test_models.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from datetime import datetime
2+
3+
from sqlalchemy import create_engine
4+
from sqlalchemy.orm import sessionmaker
5+
6+
from nettacker.database.models import Base, Report, TempEvents, HostsLog
7+
from tests.common import TestCase
8+
9+
10+
class TestModels(TestCase):
11+
def setUp(self):
12+
# Creating an in-memory SQLite database for testing
13+
self.engine = create_engine("sqlite:///:memory:")
14+
Base.metadata.create_all(self.engine)
15+
Session = sessionmaker(bind=self.engine)
16+
self.session = Session()
17+
18+
def tearDown(self):
19+
self.session.close()
20+
Base.metadata.drop_all(self.engine)
21+
22+
def test_report_model(self):
23+
test_date = datetime.now()
24+
test_report = Report(
25+
date=test_date,
26+
scan_unique_id="test123",
27+
report_path_filename="/path/to/report.txt",
28+
options='{"option1": "value1"}',
29+
)
30+
31+
self.session.add(test_report)
32+
self.session.commit()
33+
34+
retrieved_report = self.session.query(Report).first()
35+
self.assertIsNotNone(retrieved_report)
36+
self.assertEqual(retrieved_report.scan_unique_id, "test123")
37+
self.assertEqual(retrieved_report.report_path_filename, "/path/to/report.txt")
38+
self.assertEqual(retrieved_report.options, '{"option1": "value1"}')
39+
40+
repr_string = repr(retrieved_report)
41+
self.assertIn("test123", repr_string)
42+
self.assertIn("/path/to/report.txt", repr_string)
43+
44+
def test_temp_events_model(self):
45+
test_date = datetime.now()
46+
test_event = TempEvents(
47+
date=test_date,
48+
target="192.168.1.1",
49+
module_name="port_scan",
50+
scan_unique_id="test123",
51+
event_name="open_port",
52+
port="80",
53+
event="Port 80 is open",
54+
data='{"details": "HTTP server running"}',
55+
)
56+
57+
self.session.add(test_event)
58+
self.session.commit()
59+
60+
retrieved_event = self.session.query(TempEvents).first()
61+
self.assertIsNotNone(retrieved_event)
62+
self.assertEqual(retrieved_event.target, "192.168.1.1")
63+
self.assertEqual(retrieved_event.module_name, "port_scan")
64+
self.assertEqual(retrieved_event.port, "80")
65+
66+
repr_string = repr(retrieved_event)
67+
self.assertIn("192.168.1.1", repr_string)
68+
self.assertIn("port_scan", repr_string)
69+
70+
def test_hosts_log_model(self):
71+
test_date = datetime.now()
72+
test_log = HostsLog(
73+
date=test_date,
74+
target="192.168.1.1",
75+
module_name="vulnerability_scan",
76+
scan_unique_id="test123",
77+
port="443",
78+
event="Found vulnerability CVE-2021-12345",
79+
json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}',
80+
)
81+
82+
self.session.add(test_log)
83+
self.session.commit()
84+
85+
retrieved_log = self.session.query(HostsLog).first()
86+
self.assertIsNotNone(retrieved_log)
87+
self.assertEqual(retrieved_log.target, "192.168.1.1")
88+
self.assertEqual(retrieved_log.module_name, "vulnerability_scan")
89+
self.assertEqual(retrieved_log.port, "443")
90+
self.assertEqual(retrieved_log.event, "Found vulnerability CVE-2021-12345")
91+
92+
repr_string = repr(retrieved_log)
93+
self.assertIn("192.168.1.1", repr_string)
94+
self.assertIn("vulnerability_scan", repr_string)

tests/database/test_mysql.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from unittest.mock import patch, MagicMock
2+
3+
from sqlalchemy.exc import SQLAlchemyError
4+
5+
from nettacker.config import Config
6+
from nettacker.database.models import Base
7+
from nettacker.database.mysql import mysql_create_database, mysql_create_tables
8+
from tests.common import TestCase
9+
10+
11+
class TestMySQLFunctions(TestCase):
12+
"""Test cases for mysql.py functions"""
13+
14+
@patch("nettacker.database.mysql.create_engine")
15+
def test_mysql_create_database_success(self, mock_create_engine):
16+
"""Test successful database creation"""
17+
# Set up mock config
18+
Config.db = MagicMock()
19+
Config.db.as_dict.return_value = {
20+
"username": "test_user",
21+
"password": "test_pass",
22+
"host": "localhost",
23+
"port": "3306",
24+
"name": "test_db",
25+
}
26+
Config.db.name = "test_db"
27+
28+
# Set up mock connection and execution
29+
mock_conn = MagicMock()
30+
mock_engine = MagicMock()
31+
mock_create_engine.return_value = mock_engine
32+
mock_engine.connect.return_value.__enter__.return_value = mock_conn
33+
34+
# Mock database query results - database doesn't exist yet
35+
mock_conn.execute.return_value = [("mysql",), ("information_schema",)]
36+
37+
# Call the function
38+
mysql_create_database()
39+
40+
# Assertions
41+
mock_create_engine.assert_called_once_with(
42+
"mysql+pymysql://test_user:test_pass@localhost:3306"
43+
)
44+
45+
# Check that execute was called with any text object that has the expected SQL
46+
call_args_list = mock_conn.execute.call_args_list
47+
self.assertEqual(len(call_args_list), 2) # Two calls to execute
48+
49+
# Check that the first call is SHOW DATABASES
50+
first_call_arg = call_args_list[0][0][0]
51+
self.assertEqual(str(first_call_arg), "SHOW DATABASES;")
52+
53+
# Check that the second call is CREATE DATABASE
54+
second_call_arg = call_args_list[1][0][0]
55+
self.assertEqual(str(second_call_arg), "CREATE DATABASE test_db ")
56+
57+
@patch("nettacker.database.mysql.create_engine")
58+
def test_mysql_create_database_already_exists(self, mock_create_engine):
59+
"""Test when database already exists"""
60+
# Set up mock config
61+
Config.db = MagicMock()
62+
Config.db.as_dict.return_value = {
63+
"username": "test_user",
64+
"password": "test_pass",
65+
"host": "localhost",
66+
"port": "3306",
67+
"name": "test_db",
68+
}
69+
Config.db.name = "test_db"
70+
71+
# Set up mock connection and execution
72+
mock_conn = MagicMock()
73+
mock_engine = MagicMock()
74+
mock_create_engine.return_value = mock_engine
75+
mock_engine.connect.return_value.__enter__.return_value = mock_conn
76+
77+
# Mock database query results - database already exists
78+
mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]
79+
80+
# Call the function
81+
mysql_create_database()
82+
83+
# Assertions
84+
mock_create_engine.assert_called_once_with(
85+
"mysql+pymysql://test_user:test_pass@localhost:3306"
86+
)
87+
88+
# Check that execute was called once with SHOW DATABASES
89+
self.assertEqual(mock_conn.execute.call_count, 1)
90+
call_arg = mock_conn.execute.call_args[0][0]
91+
self.assertEqual(str(call_arg), "SHOW DATABASES;")
92+
93+
@patch("nettacker.database.mysql.create_engine")
94+
def test_mysql_create_database_exception(self, mock_create_engine):
95+
"""Test exception handling in create database"""
96+
# Set up mock config
97+
Config.db = MagicMock()
98+
Config.db.as_dict.return_value = {
99+
"username": "test_user",
100+
"password": "test_pass",
101+
"host": "localhost",
102+
"port": "3306",
103+
"name": "test_db",
104+
}
105+
106+
# Set up mock to raise exception
107+
mock_engine = MagicMock()
108+
mock_create_engine.return_value = mock_engine
109+
mock_engine.connect.side_effect = SQLAlchemyError("Connection error")
110+
111+
# Call the function (should not raise exception)
112+
with patch("builtins.print") as mock_print:
113+
mysql_create_database()
114+
mock_print.assert_called_once()
115+
116+
@patch("nettacker.database.mysql.create_engine")
117+
def test_mysql_create_tables(self, mock_create_engine):
118+
"""Test table creation function"""
119+
# Set up mock config
120+
Config.db = MagicMock()
121+
Config.db.as_dict.return_value = {
122+
"username": "test_user",
123+
"password": "test_pass",
124+
"host": "localhost",
125+
"port": "3306",
126+
"name": "test_db",
127+
}
128+
129+
# Set up mock engine
130+
mock_engine = MagicMock()
131+
mock_create_engine.return_value = mock_engine
132+
133+
# Call the function
134+
with patch.object(Base.metadata, "create_all") as mock_create_all:
135+
mysql_create_tables()
136+
137+
# Assertions
138+
mock_create_engine.assert_called_once_with(
139+
"mysql+pymysql://test_user:test_pass@localhost:3306/test_db"
140+
)
141+
mock_create_all.assert_called_once_with(mock_engine)

tests/database/test_postgresql.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from unittest.mock import patch, MagicMock
2+
3+
from sqlalchemy.exc import OperationalError
4+
5+
from nettacker.config import Config
6+
from nettacker.database.models import Base
7+
from nettacker.database.postgresql import postgres_create_database
8+
from tests.common import TestCase
9+
10+
11+
class TestPostgresFunctions(TestCase):
12+
@patch("nettacker.database.postgresql.create_engine")
13+
def test_postgres_create_database_success(self, mock_create_engine):
14+
Config.db = MagicMock()
15+
Config.db.as_dict.return_value = {
16+
"username": "user",
17+
"password": "pass",
18+
"host": "localhost",
19+
"port": "5432",
20+
"name": "nettacker_db",
21+
}
22+
23+
mock_engine = MagicMock()
24+
mock_create_engine.return_value = mock_engine
25+
26+
with patch.object(Base.metadata, "create_all") as mock_create_all:
27+
postgres_create_database()
28+
29+
mock_create_engine.assert_called_once_with(
30+
"postgresql+psycopg2://user:pass@localhost:5432/nettacker_db"
31+
)
32+
mock_create_all.assert_called_once_with(mock_engine)
33+
34+
@patch("nettacker.database.postgresql.create_engine")
35+
def test_postgres_create_database_if_not_exists(self, mock_create_engine):
36+
Config.db = MagicMock()
37+
Config.db.as_dict.return_value = {
38+
"username": "user",
39+
"password": "pass",
40+
"host": "localhost",
41+
"port": "5432",
42+
"name": "nettacker_db",
43+
}
44+
Config.db.name = "nettacker_db"
45+
46+
mock_engine_initial = MagicMock()
47+
mock_engine_fallback = MagicMock()
48+
mock_engine_final = MagicMock()
49+
50+
mock_create_engine.side_effect = [
51+
mock_engine_initial,
52+
mock_engine_fallback,
53+
mock_engine_final,
54+
]
55+
56+
with patch.object(
57+
Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None]
58+
):
59+
mock_conn = MagicMock()
60+
mock_engine_fallback.connect.return_value = mock_conn
61+
mock_conn.execution_options.return_value = mock_conn
62+
63+
postgres_create_database()
64+
65+
assert mock_create_engine.call_count == 3
66+
args, _ = mock_conn.execute.call_args
67+
assert str(args[0]) == "CREATE DATABASE nettacker_db"
68+
mock_conn.close.assert_called_once()
69+
70+
@patch("nettacker.database.postgresql.create_engine")
71+
def test_postgres_create_database_create_fail(self, mock_create_engine):
72+
Config.db = MagicMock()
73+
Config.db.as_dict.return_value = {
74+
"username": "user",
75+
"password": "pass",
76+
"host": "localhost",
77+
"port": "5432",
78+
"name": "nettacker_db",
79+
}
80+
81+
mock_engine_initial = MagicMock()
82+
mock_engine_fallback = MagicMock()
83+
84+
mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback]
85+
86+
mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None)
87+
88+
with patch.object(
89+
Base.metadata, "create_all", side_effect=OperationalError("fail", None, None)
90+
):
91+
with self.assertRaises(OperationalError):
92+
postgres_create_database()

tests/database/test_sqlite.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from unittest.mock import patch, MagicMock
2+
3+
from sqlalchemy import create_engine, inspect
4+
5+
from nettacker.config import Config
6+
from nettacker.database.models import Base
7+
from nettacker.database.sqlite import sqlite_create_tables
8+
from tests.common import TestCase
9+
10+
11+
class TestSQLiteFunctions(TestCase):
12+
@patch("nettacker.database.sqlite.create_engine")
13+
def test_sqlite_create_tables(self, mock_create_engine):
14+
Config.db = MagicMock()
15+
Config.db.as_dict.return_value = {"name": "/path/to/test.db"}
16+
17+
mock_engine = MagicMock()
18+
mock_create_engine.return_value = mock_engine
19+
20+
with patch.object(Base.metadata, "create_all") as mock_create_all:
21+
sqlite_create_tables()
22+
23+
mock_create_engine.assert_called_once_with(
24+
"sqlite:////path/to/test.db", connect_args={"check_same_thread": False}
25+
)
26+
mock_create_all.assert_called_once_with(mock_engine)
27+
28+
def test_sqlite_create_tables_integration(self):
29+
engine = create_engine("sqlite:///:memory:")
30+
31+
Config.db = MagicMock()
32+
Config.db.as_dict.return_value = {"name": ":memory:"}
33+
34+
with patch("nettacker.database.sqlite.create_engine", return_value=engine):
35+
sqlite_create_tables()
36+
37+
inspector = inspect(engine)
38+
tables = inspector.get_table_names()
39+
40+
self.assertIn("reports", tables, "Reports table was not created")
41+
self.assertIn("temp_events", tables, "Temp events table was not created")
42+
self.assertIn("scan_events", tables, "Scan events table was not created")

0 commit comments

Comments
 (0)