Skip to content

Commit ba1032b

Browse files
committed
test: add database.py tests (68% coverage)
1 parent 55f8702 commit ba1032b

File tree

1 file changed

+230
-0
lines changed

1 file changed

+230
-0
lines changed

apps/api/tests/test_database.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
"""Tests for database.py"""
2+
3+
import pytest
4+
5+
from app.services.database import (
6+
ConnectionTestResult,
7+
DatabaseConfig,
8+
DatabaseManager,
9+
QueryResult,
10+
create_database_manager,
11+
)
12+
13+
14+
class TestDatabaseConfig:
15+
"""Test DatabaseConfig class"""
16+
17+
def test_from_dict_basic(self):
18+
"""Test creating config from dict"""
19+
data = {
20+
"driver": "mysql",
21+
"host": "localhost",
22+
"port": 3306,
23+
"user": "root",
24+
"password": "secret",
25+
"database": "testdb",
26+
}
27+
config = DatabaseConfig.from_dict(data)
28+
assert config.driver == "mysql"
29+
assert config.host == "localhost"
30+
assert config.port == 3306
31+
assert config.user == "root"
32+
assert config.password == "secret"
33+
assert config.database == "testdb"
34+
35+
def test_from_dict_with_username(self):
36+
"""Test creating config with username field"""
37+
data = {
38+
"driver": "postgresql",
39+
"username": "admin",
40+
"database_name": "mydb",
41+
}
42+
config = DatabaseConfig.from_dict(data)
43+
assert config.user == "admin"
44+
assert config.database == "mydb"
45+
46+
def test_from_dict_defaults(self):
47+
"""Test default values"""
48+
data = {"driver": "sqlite"}
49+
config = DatabaseConfig.from_dict(data)
50+
assert config.host == "localhost"
51+
assert config.port is None
52+
assert config.user == ""
53+
assert config.password == ""
54+
55+
def test_get_port_mysql(self):
56+
"""Test MySQL default port"""
57+
config = DatabaseConfig(driver="mysql")
58+
assert config.get_port() == 3306
59+
60+
def test_get_port_postgresql(self):
61+
"""Test PostgreSQL default port"""
62+
config = DatabaseConfig(driver="postgresql")
63+
assert config.get_port() == 5432
64+
65+
def test_get_port_sqlite(self):
66+
"""Test SQLite port"""
67+
config = DatabaseConfig(driver="sqlite")
68+
assert config.get_port() == 0
69+
70+
def test_get_port_custom(self):
71+
"""Test custom port"""
72+
config = DatabaseConfig(driver="mysql", port=3307)
73+
assert config.get_port() == 3307
74+
75+
76+
class TestDatabaseManager:
77+
"""Test DatabaseManager class"""
78+
79+
def test_init_mysql(self):
80+
"""Test MySQL manager initialization"""
81+
config = DatabaseConfig(driver="mysql", host="localhost", database="test")
82+
manager = DatabaseManager(config)
83+
assert manager.config.driver == "mysql"
84+
85+
def test_init_postgresql(self):
86+
"""Test PostgreSQL manager initialization"""
87+
config = DatabaseConfig(driver="postgresql", host="localhost", database="test")
88+
manager = DatabaseManager(config)
89+
assert manager.config.driver == "postgresql"
90+
91+
def test_init_sqlite(self):
92+
"""Test SQLite manager initialization"""
93+
config = DatabaseConfig(driver="sqlite", database=":memory:")
94+
manager = DatabaseManager(config)
95+
assert manager.config.driver == "sqlite"
96+
97+
def test_init_unsupported_driver(self):
98+
"""Test unsupported driver raises error"""
99+
config = DatabaseConfig(driver="oracle", database="test")
100+
with pytest.raises(ValueError, match="不支持的数据库类型"):
101+
DatabaseManager(config)
102+
103+
def test_supported_drivers(self):
104+
"""Test supported drivers list"""
105+
assert "mysql" in DatabaseManager.SUPPORTED_DRIVERS
106+
assert "postgresql" in DatabaseManager.SUPPORTED_DRIVERS
107+
assert "sqlite" in DatabaseManager.SUPPORTED_DRIVERS
108+
109+
def test_read_only_prefixes(self):
110+
"""Test read-only SQL prefixes"""
111+
assert "SELECT" in DatabaseManager.READ_ONLY_PREFIXES
112+
assert "SHOW" in DatabaseManager.READ_ONLY_PREFIXES
113+
assert "DESCRIBE" in DatabaseManager.READ_ONLY_PREFIXES
114+
assert "EXPLAIN" in DatabaseManager.READ_ONLY_PREFIXES
115+
116+
117+
class TestSQLiteManager:
118+
"""Test SQLite specific functionality"""
119+
120+
@pytest.fixture
121+
def sqlite_manager(self, tmp_path):
122+
"""Create SQLite manager with temp file database"""
123+
db_path = tmp_path / "test.db"
124+
config = DatabaseConfig(driver="sqlite", database=str(db_path))
125+
return DatabaseManager(config)
126+
127+
def test_connect(self, sqlite_manager):
128+
"""Test SQLite connection"""
129+
with sqlite_manager.connect() as conn:
130+
assert conn is not None
131+
132+
def test_test_connection(self, sqlite_manager):
133+
"""Test connection test"""
134+
result = sqlite_manager.test_connection()
135+
assert result.connected is True
136+
assert result.version is not None
137+
138+
def test_execute_query(self, sqlite_manager):
139+
"""Test query execution"""
140+
# Create a table first
141+
with sqlite_manager.connect() as conn:
142+
cursor = conn.cursor()
143+
cursor.execute("CREATE TABLE test (id INTEGER, name TEXT)")
144+
cursor.execute("INSERT INTO test VALUES (1, 'Alice')")
145+
cursor.execute("INSERT INTO test VALUES (2, 'Bob')")
146+
conn.commit()
147+
148+
# Query the table
149+
result = sqlite_manager.execute_query("SELECT * FROM test")
150+
assert result.rows_count == 2
151+
assert len(result.data) == 2
152+
153+
def test_execute_query_read_only(self, sqlite_manager):
154+
"""Test read-only mode blocks writes"""
155+
with pytest.raises(ValueError, match="只允许执行只读查询"):
156+
sqlite_manager.execute_query("DROP TABLE test", read_only=True)
157+
158+
def test_get_schema_info(self, sqlite_manager):
159+
"""Test schema info retrieval"""
160+
# Create a table
161+
with sqlite_manager.connect() as conn:
162+
cursor = conn.cursor()
163+
cursor.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
164+
conn.commit()
165+
166+
schema = sqlite_manager.get_schema_info()
167+
assert "users" in schema
168+
169+
170+
class TestQueryResult:
171+
"""Test QueryResult class"""
172+
173+
def test_query_result(self):
174+
"""Test QueryResult creation"""
175+
result = QueryResult(
176+
data=[{"id": 1, "name": "test"}],
177+
rows_count=1,
178+
)
179+
assert result.rows_count == 1
180+
assert len(result.data) == 1
181+
182+
183+
class TestConnectionTestResult:
184+
"""Test ConnectionTestResult class"""
185+
186+
def test_connection_test_result_success(self):
187+
"""Test successful connection result"""
188+
result = ConnectionTestResult(
189+
connected=True,
190+
version="8.0.32",
191+
tables_count=10,
192+
message="Connected successfully",
193+
)
194+
assert result.connected is True
195+
assert result.version == "8.0.32"
196+
197+
def test_connection_test_result_failure(self):
198+
"""Test failed connection result"""
199+
result = ConnectionTestResult(
200+
connected=False,
201+
message="Connection refused",
202+
)
203+
assert result.connected is False
204+
assert result.version is None
205+
206+
207+
class TestCreateDatabaseManager:
208+
"""Test create_database_manager factory function"""
209+
210+
def test_create_from_dict(self):
211+
"""Test creating manager from dict"""
212+
config = {
213+
"driver": "sqlite",
214+
"database": ":memory:",
215+
}
216+
manager = create_database_manager(config)
217+
assert manager.config.driver == "sqlite"
218+
219+
def test_create_mysql(self):
220+
"""Test creating MySQL manager"""
221+
config = {
222+
"driver": "mysql",
223+
"host": "localhost",
224+
"port": 3306,
225+
"user": "root",
226+
"password": "secret",
227+
"database": "testdb",
228+
}
229+
manager = create_database_manager(config)
230+
assert manager.config.driver == "mysql"

0 commit comments

Comments
 (0)