Skip to content

Commit cfe3718

Browse files
committed
added sqlite tests
1 parent 05d5db8 commit cfe3718

File tree

1 file changed

+65
-109
lines changed

1 file changed

+65
-109
lines changed

tests/test_sql_db.py

Lines changed: 65 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,77 @@
11
import os
2-
import sqlite3
2+
import tempfile
33

44
import pytest
5-
from sqlalchemy import INTEGER, TEXT
6-
7-
from rowgen.extract_from_db import DBconnect
8-
9-
10-
@pytest.fixture
11-
def temp_db():
12-
db_path = "test.db"
13-
conn = sqlite3.connect(db_path)
14-
cursor = conn.cursor()
15-
cursor.execute(
16-
"""
17-
CREATE TABLE users (
18-
id INTEGER PRIMARY KEY,
19-
name TEXT NOT NULL,
20-
username TEXT UNIQUE,
21-
rank INTEGER);
22-
"""
5+
from sqlalchemy import (
6+
Column,
7+
Integer,
8+
String,
9+
ForeignKey,
10+
CheckConstraint,
11+
UniqueConstraint,
12+
create_engine,
13+
MetaData,
14+
Table,
15+
)
16+
17+
from rowgen.extract_from_db import (
18+
extract_db_schema,
19+
)
20+
21+
22+
@pytest.fixture(scope="function")
23+
def sqlite_db_url():
24+
return "sqlite:///:memory:"
25+
26+
27+
@pytest.fixture(scope="function")
28+
def setup_database():
29+
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
30+
db_url = f"sqlite:///{tmp.name}"
31+
engine = create_engine(db_url)
32+
metadata = MetaData()
33+
34+
parent = Table(
35+
"parent",
36+
metadata,
37+
Column("id", Integer, primary_key=True, autoincrement=True),
38+
Column("name", String, nullable=False, unique=True),
39+
CheckConstraint("length(name) > 1", name="name_length_check"),
2340
)
24-
conn.commit()
25-
conn.close()
26-
url = f"sqlite:///{db_path}"
27-
yield url
28-
# Cleanup after test
29-
if os.path.exists(db_path):
30-
os.remove(db_path)
31-
32-
33-
@pytest.fixture
34-
def db_connect(temp_db):
35-
dbc = DBconnect(temp_db)
36-
yield dbc
37-
38-
39-
@pytest.fixture
40-
def empty_temp_db():
41-
db_path = "empty_test.db"
42-
conn = sqlite3.connect(db_path)
43-
conn.close()
44-
url = f"sqlite:///{db_path}"
45-
yield url
46-
if os.path.exists(db_path):
47-
os.remove(db_path)
4841

42+
child = Table(
43+
"child",
44+
metadata,
45+
Column("id", Integer, primary_key=True),
46+
Column("parent_id", Integer, ForeignKey("parent.id", ondelete="CASCADE")),
47+
UniqueConstraint("parent_id", name="uq_child_parent_id"),
48+
)
4949

50-
# def test_empty_db_table_columns(empty_temp_db):
51-
# dbc = DBconnect(empty_temp_db)
52-
# assert dbc.table_columns == {}
53-
54-
55-
def test_get_columns(db_connect):
56-
# cols = db_connect.table_columns
57-
expected = {
58-
"users": [
59-
{
60-
"name": "id",
61-
"type": INTEGER(),
62-
"nullable": True,
63-
"default": None,
64-
"primary_key": 1,
65-
},
66-
{
67-
"name": "name",
68-
"type": TEXT(),
69-
"nullable": False,
70-
"default": None,
71-
"primary_key": 0,
72-
},
73-
{
74-
"name": "username",
75-
"type": TEXT(),
76-
"nullable": True,
77-
"default": None,
78-
"primary_key": 0,
79-
},
80-
{
81-
"name": "rank",
82-
"type": INTEGER(),
83-
"nullable": True,
84-
"default": None,
85-
"primary_key": 0,
86-
},
87-
]
88-
}
89-
90-
result = db_connect.table_columns
91-
92-
# We compare only the relevant parts because type() instances won't compare cleanly
93-
def clean(col):
94-
return {
95-
"name": col["name"],
96-
"type": type(col["type"]), # type comparison by class
97-
"nullable": col["nullable"],
98-
"default": col["default"],
99-
"primary_key": col["primary_key"],
100-
}
101-
102-
cleaned_result = {
103-
table: [clean(col) for col in cols] for table, cols in result.items()
104-
}
50+
metadata.create_all(engine)
51+
yield db_url # yield the URL instead of the engine
52+
engine.dispose()
53+
os.remove(tmp.name)
10554

106-
cleaned_expected = {
107-
table: [clean(col) for col in cols] for table, cols in expected.items()
108-
}
10955

110-
assert cleaned_result == cleaned_expected
56+
def test_extract_schema_basic(setup_database):
57+
db_url = setup_database
58+
schema = extract_db_schema(db_url)
59+
assert "tables" in schema
60+
assert "parent" in schema["tables"]
61+
assert "child" in schema["tables"]
11162

63+
parent = schema["tables"]["parent"]
64+
assert any(c["primary_key"] for c in parent["columns"])
65+
assert any(c["unique"] or c["name"] == "name" for c in parent["columns"])
66+
assert len(parent["check_constraints"]) > 0
67+
assert parent["primary_key"] == ["id"]
11268

113-
# def test_invalid_db_url():
114-
# bad_url = "sqlite:///non_existent_folder/non_existent_file.db"
115-
# with pytest.raises(OperationalError):
116-
# DBconnect(bad_url)
69+
child = schema["tables"]["child"]
70+
assert any(fk["target_table"] == "parent" for fk in child["foreign_keys"])
71+
assert child["primary_key"] == ["id"]
72+
assert any(uc["name"] == "uq_child_parent_id" for uc in child["unique_constraints"])
11773

11874

119-
def test_tables_list(db_connect):
120-
tables = db_connect.tables
121-
assert "users" in tables
75+
def test_invalid_url_raises():
76+
with pytest.raises(Exception, match="Failed to extract schema"):
77+
extract_db_schema("invalid_url")

0 commit comments

Comments
 (0)