|
1 | 1 | import pytest |
2 | | -from sqlmodel import SQLModel |
| 2 | +from sqlmodel import Session, SQLModel, create_engine |
| 3 | +from sqlmodel.pool import StaticPool |
3 | 4 |
|
4 | 5 | from src.domain.issue import Issue, IssueState |
5 | | -from src.resource_adapters.persistence.sqlmodel.database import get_engine |
6 | | -from src.resource_adapters.persistence.sqlmodel.unit_of_work import SQLModelUnitOfWork |
| 6 | +from src.resource_adapters.persistence.sqlmodel.issues import SQLModelIssueRepository |
7 | 7 |
|
8 | 8 |
|
9 | | -@pytest.fixture |
10 | | -def uow(): |
11 | | - """Create a SQLModelUnitOfWork with in-memory database.""" |
12 | | - # Reset the global engine for each test |
13 | | - import src.resource_adapters.persistence.sqlmodel.database as db |
14 | | - |
15 | | - db._engine = None |
16 | | - |
17 | | - # Create a fresh database |
18 | | - database_url = "sqlite://" |
19 | | - uow = SQLModelUnitOfWork(database_url=database_url) |
20 | | - engine = get_engine(database_url) |
21 | | - SQLModel.metadata.create_all(engine) |
22 | | - return uow |
23 | | - |
24 | | - |
25 | | -class TestSQLModelRepository: |
| 9 | +# https://sqlmodel.tiangolo.com/tutorial/fastapi/tests/ |
| 10 | +class TestSQLModelIssueRepository: |
26 | 11 | """Test suite for SQLModel-based issue repository.""" |
27 | 12 |
|
28 | | - @pytest.mark.anyio |
29 | | - async def test_add_and_get_issue(self, uow): |
| 13 | + @pytest.fixture(name="session") |
| 14 | + def session_fixture(self): |
| 15 | + engine = create_engine( |
| 16 | + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool |
| 17 | + ) |
| 18 | + SQLModel.metadata.create_all(engine) |
| 19 | + with Session(engine) as session: |
| 20 | + yield session |
| 21 | + |
| 22 | + def test_add_and_get_issue(self, session: Session): |
30 | 23 | """Test adding and retrieving an issue.""" |
31 | 24 | issue_number = 1 |
32 | 25 | test_issue = Issue(issue_number=issue_number, issue_state=IssueState.OPEN) |
33 | 26 |
|
| 27 | + uow = SQLModelIssueRepository(session) |
| 28 | + |
34 | 29 | with uow: |
35 | | - await uow.issues.add(test_issue) |
| 30 | + uow.add(test_issue) |
36 | 31 | uow.commit() |
37 | 32 |
|
38 | 33 | with uow: |
39 | | - retrieved_issue = await uow.issues.get_by_id(issue_number) |
| 34 | + retrieved_issue = uow.get_by_id(issue_number) |
40 | 35 | assert retrieved_issue.issue_number == issue_number |
41 | 36 | assert retrieved_issue.issue_state == IssueState.OPEN |
42 | 37 |
|
43 | | - @pytest.mark.anyio |
44 | | - async def test_list_issues(self, uow): |
| 38 | + def test_list_issues(self, session: Session): |
45 | 39 | """Test listing all issues.""" |
46 | 40 | issue_number = 1 |
47 | 41 | test_issue = Issue(issue_number=issue_number, issue_state=IssueState.OPEN) |
48 | 42 |
|
| 43 | + uow = SQLModelIssueRepository(session) |
| 44 | + |
49 | 45 | with uow: |
50 | | - await uow.issues.add(test_issue) |
| 46 | + uow.add(test_issue) |
51 | 47 | uow.commit() |
52 | 48 |
|
53 | 49 | with uow: |
54 | | - issues = await uow.issues.list() |
| 50 | + issues = uow.list() |
55 | 51 | assert len(issues) == 1 |
56 | 52 | assert issues[0].issue_number == issue_number |
57 | 53 |
|
58 | | - @pytest.mark.anyio |
59 | | - async def test_update_issue(self, uow): |
| 54 | + def test_update_issue(self, session: Session): |
60 | 55 | """Test updating an issue's state.""" |
61 | 56 | issue_number = 1 |
62 | 57 | test_issue = Issue(issue_number=issue_number, issue_state=IssueState.OPEN) |
63 | 58 |
|
| 59 | + uow = SQLModelIssueRepository(session) |
64 | 60 | with uow: |
65 | | - await uow.issues.add(test_issue) |
| 61 | + uow.add(test_issue) |
66 | 62 | uow.commit() |
67 | 63 |
|
68 | 64 | with uow: |
69 | | - issue_to_update = await uow.issues.get_by_id(issue_number) |
| 65 | + issue_to_update = uow.get_by_id(issue_number) |
70 | 66 | issue_to_update.issue_state = IssueState.CLOSED |
71 | | - await uow.issues.update(issue_to_update) |
| 67 | + uow.update(issue_to_update) |
72 | 68 | uow.commit() |
73 | 69 |
|
74 | 70 | with uow: |
75 | | - updated_issue = await uow.issues.get_by_id(issue_number) |
| 71 | + updated_issue = uow.get_by_id(issue_number) |
76 | 72 | assert updated_issue.issue_state == IssueState.CLOSED |
77 | 73 |
|
78 | | - @pytest.mark.anyio |
79 | | - async def test_filter_issues(self, uow): |
| 74 | + def test_filter_issues(self, session: Session): |
80 | 75 | """Test filtering issues with a predicate.""" |
81 | 76 | issue_number = 1 |
82 | 77 | test_issue = Issue(issue_number=issue_number, issue_state=IssueState.CLOSED) |
83 | 78 |
|
| 79 | + uow = SQLModelIssueRepository(session) |
84 | 80 | with uow: |
85 | | - await uow.issues.add(test_issue) |
| 81 | + uow.add(test_issue) |
86 | 82 | uow.commit() |
87 | 83 |
|
88 | 84 | with uow: |
89 | | - closed_issues = await uow.issues.list_with_predicate( |
| 85 | + closed_issues = uow.list_with_predicate( |
90 | 86 | lambda i: i.issue_state == IssueState.CLOSED |
91 | 87 | ) |
92 | 88 | assert len(closed_issues) == 1 |
93 | 89 |
|
94 | | - @pytest.mark.anyio |
95 | | - async def test_remove_issue(self, uow): |
| 90 | + def test_remove_issue(self, session: Session): |
96 | 91 | """Test removing an issue.""" |
97 | 92 | issue_number = 1 |
98 | 93 | test_issue = Issue(issue_number=issue_number, issue_state=IssueState.OPEN) |
99 | 94 |
|
| 95 | + uow = SQLModelIssueRepository(session) |
100 | 96 | with uow: |
101 | | - await uow.issues.add(test_issue) |
| 97 | + uow.add(test_issue) |
102 | 98 | uow.commit() |
103 | 99 |
|
104 | 100 | with uow: |
105 | | - issue_to_remove = await uow.issues.get_by_id(issue_number) |
106 | | - await uow.issues.remove(issue_to_remove) |
| 101 | + issue_to_remove = uow.get_by_id(issue_number) |
| 102 | + uow.remove(issue_to_remove) |
107 | 103 | uow.commit() |
108 | 104 |
|
109 | 105 | with uow: |
110 | | - remaining_issues = await uow.issues.list() |
| 106 | + remaining_issues = uow.list() |
111 | 107 | assert len(remaining_issues) == 0 |
0 commit comments