|
1 | 1 | import os |
2 | | -import sqlite3 |
| 2 | +import tempfile |
3 | 3 |
|
4 | 4 | import pytest |
5 | | -from sqlalchemy import INTEGER, TEXT |
6 | | - |
7 | | -from rowgen.extract_from_db import DBconnect |
8 | | - |
9 | | - |
10 | | -@pytest.fixture |
11 | | -def temp_db(): |
12 | | - db_path = "test.db" |
13 | | - conn = sqlite3.connect(db_path) |
14 | | - cursor = conn.cursor() |
15 | | - cursor.execute( |
16 | | - """ |
17 | | - CREATE TABLE users ( |
18 | | - id INTEGER PRIMARY KEY, |
19 | | - name TEXT NOT NULL, |
20 | | - username TEXT UNIQUE, |
21 | | - rank INTEGER); |
22 | | - """ |
| 5 | +from sqlalchemy import ( |
| 6 | + Column, |
| 7 | + Integer, |
| 8 | + String, |
| 9 | + ForeignKey, |
| 10 | + CheckConstraint, |
| 11 | + UniqueConstraint, |
| 12 | + create_engine, |
| 13 | + MetaData, |
| 14 | + Table, |
| 15 | +) |
| 16 | + |
| 17 | +from rowgen.extract_from_db import ( |
| 18 | + extract_db_schema, |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +@pytest.fixture(scope="function") |
| 23 | +def sqlite_db_url(): |
| 24 | + return "sqlite:///:memory:" |
| 25 | + |
| 26 | + |
| 27 | +@pytest.fixture(scope="function") |
| 28 | +def setup_database(): |
| 29 | + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: |
| 30 | + db_url = f"sqlite:///{tmp.name}" |
| 31 | + engine = create_engine(db_url) |
| 32 | + metadata = MetaData() |
| 33 | + |
| 34 | + parent = Table( |
| 35 | + "parent", |
| 36 | + metadata, |
| 37 | + Column("id", Integer, primary_key=True, autoincrement=True), |
| 38 | + Column("name", String, nullable=False, unique=True), |
| 39 | + CheckConstraint("length(name) > 1", name="name_length_check"), |
23 | 40 | ) |
24 | | - conn.commit() |
25 | | - conn.close() |
26 | | - url = f"sqlite:///{db_path}" |
27 | | - yield url |
28 | | - # Cleanup after test |
29 | | - if os.path.exists(db_path): |
30 | | - os.remove(db_path) |
31 | | - |
32 | | - |
33 | | -@pytest.fixture |
34 | | -def db_connect(temp_db): |
35 | | - dbc = DBconnect(temp_db) |
36 | | - yield dbc |
37 | | - |
38 | | - |
39 | | -@pytest.fixture |
40 | | -def empty_temp_db(): |
41 | | - db_path = "empty_test.db" |
42 | | - conn = sqlite3.connect(db_path) |
43 | | - conn.close() |
44 | | - url = f"sqlite:///{db_path}" |
45 | | - yield url |
46 | | - if os.path.exists(db_path): |
47 | | - os.remove(db_path) |
48 | 41 |
|
| 42 | + child = Table( |
| 43 | + "child", |
| 44 | + metadata, |
| 45 | + Column("id", Integer, primary_key=True), |
| 46 | + Column("parent_id", Integer, ForeignKey("parent.id", ondelete="CASCADE")), |
| 47 | + UniqueConstraint("parent_id", name="uq_child_parent_id"), |
| 48 | + ) |
49 | 49 |
|
50 | | -# def test_empty_db_table_columns(empty_temp_db): |
51 | | -# dbc = DBconnect(empty_temp_db) |
52 | | -# assert dbc.table_columns == {} |
53 | | - |
54 | | - |
55 | | -def test_get_columns(db_connect): |
56 | | - # cols = db_connect.table_columns |
57 | | - expected = { |
58 | | - "users": [ |
59 | | - { |
60 | | - "name": "id", |
61 | | - "type": INTEGER(), |
62 | | - "nullable": True, |
63 | | - "default": None, |
64 | | - "primary_key": 1, |
65 | | - }, |
66 | | - { |
67 | | - "name": "name", |
68 | | - "type": TEXT(), |
69 | | - "nullable": False, |
70 | | - "default": None, |
71 | | - "primary_key": 0, |
72 | | - }, |
73 | | - { |
74 | | - "name": "username", |
75 | | - "type": TEXT(), |
76 | | - "nullable": True, |
77 | | - "default": None, |
78 | | - "primary_key": 0, |
79 | | - }, |
80 | | - { |
81 | | - "name": "rank", |
82 | | - "type": INTEGER(), |
83 | | - "nullable": True, |
84 | | - "default": None, |
85 | | - "primary_key": 0, |
86 | | - }, |
87 | | - ] |
88 | | - } |
89 | | - |
90 | | - result = db_connect.table_columns |
91 | | - |
92 | | - # We compare only the relevant parts because type() instances won't compare cleanly |
93 | | - def clean(col): |
94 | | - return { |
95 | | - "name": col["name"], |
96 | | - "type": type(col["type"]), # type comparison by class |
97 | | - "nullable": col["nullable"], |
98 | | - "default": col["default"], |
99 | | - "primary_key": col["primary_key"], |
100 | | - } |
101 | | - |
102 | | - cleaned_result = { |
103 | | - table: [clean(col) for col in cols] for table, cols in result.items() |
104 | | - } |
| 50 | + metadata.create_all(engine) |
| 51 | + yield db_url # yield the URL instead of the engine |
| 52 | + engine.dispose() |
| 53 | + os.remove(tmp.name) |
105 | 54 |
|
106 | | - cleaned_expected = { |
107 | | - table: [clean(col) for col in cols] for table, cols in expected.items() |
108 | | - } |
109 | 55 |
|
110 | | - assert cleaned_result == cleaned_expected |
| 56 | +def test_extract_schema_basic(setup_database): |
| 57 | + db_url = setup_database |
| 58 | + schema = extract_db_schema(db_url) |
| 59 | + assert "tables" in schema |
| 60 | + assert "parent" in schema["tables"] |
| 61 | + assert "child" in schema["tables"] |
111 | 62 |
|
| 63 | + parent = schema["tables"]["parent"] |
| 64 | + assert any(c["primary_key"] for c in parent["columns"]) |
| 65 | + assert any(c["unique"] or c["name"] == "name" for c in parent["columns"]) |
| 66 | + assert len(parent["check_constraints"]) > 0 |
| 67 | + assert parent["primary_key"] == ["id"] |
112 | 68 |
|
113 | | -# def test_invalid_db_url(): |
114 | | -# bad_url = "sqlite:///non_existent_folder/non_existent_file.db" |
115 | | -# with pytest.raises(OperationalError): |
116 | | -# DBconnect(bad_url) |
| 69 | + child = schema["tables"]["child"] |
| 70 | + assert any(fk["target_table"] == "parent" for fk in child["foreign_keys"]) |
| 71 | + assert child["primary_key"] == ["id"] |
| 72 | + assert any(uc["name"] == "uq_child_parent_id" for uc in child["unique_constraints"]) |
117 | 73 |
|
118 | 74 |
|
119 | | -def test_tables_list(db_connect): |
120 | | - tables = db_connect.tables |
121 | | - assert "users" in tables |
| 75 | +def test_invalid_url_raises(): |
| 76 | + with pytest.raises(Exception, match="Failed to extract schema"): |
| 77 | + extract_db_schema("invalid_url") |
0 commit comments