Skip to content

Commit ba8be3b

Browse files
committed
refactor(tests): streamline issue repository tests and add filters
1 parent 690dab5 commit ba8be3b

File tree

10 files changed

+348
-290
lines changed

10 files changed

+348
-290
lines changed

main.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from src.interface_adapters import api_router
1515
from src.interface_adapters.exceptions import AppException
1616
from src.interface_adapters.middleware.error_handler import app_exception_handler
17-
from src.resource_adapters.persistence.sqlmodel.database import init_db
1817

1918
# https://brandur.org/logfmt
2019
# https://github.com/Delgan/loguru
@@ -37,11 +36,8 @@ async def lifespan(app: FastAPI):
3736
logger.info(f"App state: {dict(app.state.__dict__)}")
3837
logger.info(app.state.running)
3938

40-
# Initialize database if using SQLModel
41-
if Settings.get_settings().execution_mode == "sqlmodel" and not Settings.get_settings().migrate_database:
42-
init_db()
43-
4439
yield
40+
4541
app.state.running = False
4642
logger.info("Lifespan stopped")
4743

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ authors = [{ name = "Chris Markus", email = "[email protected]" }]
66
requires-python = "~=3.12"
77
readme = "README.md"
88
dependencies = [
9+
"dynaconf>=3.2.7",
910
"fastapi[standard]>=0.115.6",
1011
"httpx>=0.28.1",
1112
"loguru>=0.7.3",
@@ -22,7 +23,9 @@ requires = ["hatchling"]
2223
build-backend = "hatchling.build"
2324

2425
[dependency-groups]
25-
dev = ["pytest>=8.3.4", "pytest-asyncio>=0.23.5", "pytest-anyio>=0.1.0"]
26+
dev = [
27+
"pytest>=8.3.4",
28+
]
2629

2730
[tool.pytest.ini_options]
2831
asyncio_mode = "auto"

src/resource_adapters/persistence/sqlmodel/database.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def get_engine(database_url: str | None = None) -> Engine:
2727
_engine = create_engine(
2828
database_url, connect_args={"check_same_thread": False}, echo=True
2929
)
30+
# Initialize database if using SQLModel
31+
if Settings.get_settings().execution_mode == "sqlmodel" and not Settings.get_settings().migrate_database:
32+
init_db()
33+
3034
return _engine
3135

3236

src/resource_adapters/persistence/sqlmodel/issues.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Callable, List
1+
from typing import Callable, List, Union
22

33
from loguru import logger
4+
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList, ClauseElement
45
from sqlmodel import Session, select
56

67
from src.app.ports.repositories.issues import IssueRepository
@@ -25,11 +26,19 @@ def list(self) -> List[Issue]:
2526
results = self.session.exec(statement).all()
2627
return results
2728

28-
def list_with_predicate(self, predicate: Callable[[Issue], bool]) -> List[Issue]:
29-
# First get all issues and then filter in memory
30-
# For better performance, specific predicates could be translated to SQL filters
31-
all_issues = self.list()
32-
return [issue for issue in all_issues if predicate(issue)]
29+
def list_with_predicate(
30+
self, predicate: Union[Callable[[Issue], bool], ClauseElement]
31+
) -> List[Issue]:
32+
if isinstance(predicate, (BinaryExpression, BooleanClauseList)):
33+
# If we're passed a SQLModel/SQLAlchemy filter condition, use it directly
34+
# open_issues = repo.list_with_predicate(Issue.issue_state == IssueState.OPEN)
35+
statement = select(Issue).where(predicate)
36+
return self.session.exec(statement).all()
37+
else:
38+
# Fall back to in-memory filtering for complex predicates that can't be expressed in SQL
39+
# open_issues = repository.list_with_predicate(lambda issue: issue.issue_state == IssueState.OPEN)
40+
all_issues = self.list()
41+
return [issue for issue in all_issues if predicate(issue)]
3342

3443
def add(self, entity: Issue) -> None:
3544
logger.info(f"adding issue: {entity.issue_number}")
@@ -45,6 +54,7 @@ def update(self, entity: Issue) -> None:
4554
# Update fields from the detached entity
4655
existing.issue_state = entity.issue_state
4756
existing.version = entity.version
57+
self.session.add(existing)
4858

4959
def remove(self, entity: Issue) -> None:
5060
logger.info(f"removing issue: {entity}")

tests/__init__.py

Whitespace-only changes.

tests/domain/test_issue.py

Lines changed: 50 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,64 +3,63 @@
33
from src.domain.issue import IssueState, IssueTransitionType
44

55

6-
class TestIssueState:
7-
def test_initial_state(self):
8-
"""Test that initial states are created correctly"""
9-
assert IssueState.OPEN.value == "OPEN"
10-
assert IssueState.CLOSED.value == "CLOSED"
6+
def test_initial_state():
7+
"""Test that initial states are created correctly"""
8+
assert IssueState.OPEN.value == "OPEN"
9+
assert IssueState.CLOSED.value == "CLOSED"
1110

12-
def test_is_open_property(self):
13-
"""Test the is_open property returns correct values"""
14-
assert IssueState.OPEN.is_open is True
15-
assert IssueState.CLOSED.is_open is False
11+
def test_is_open_property():
12+
"""Test the is_open property returns correct values"""
13+
assert IssueState.OPEN.is_open is True
14+
assert IssueState.CLOSED.is_open is False
1615

17-
def test_valid_transitions(self):
18-
"""Test all valid state transitions"""
19-
# Test closing an open issue
20-
state = IssueState.OPEN
21-
new_state = state.transition(IssueTransitionType.CLOSE_AS_COMPLETE)
22-
assert new_state == IssueState.CLOSED
16+
def test_valid_transitions():
17+
"""Test all valid state transitions"""
18+
# Test closing an open issue
19+
state = IssueState.OPEN
20+
new_state = state.transition(IssueTransitionType.CLOSE_AS_COMPLETE)
21+
assert new_state == IssueState.CLOSED
2322

24-
state = IssueState.OPEN
25-
new_state = state.transition(IssueTransitionType.CLOSE_AS_NOT_PLANNED)
26-
assert new_state == IssueState.CLOSED
23+
state = IssueState.OPEN
24+
new_state = state.transition(IssueTransitionType.CLOSE_AS_NOT_PLANNED)
25+
assert new_state == IssueState.CLOSED
2726

28-
# Test reopening a closed issue
29-
state = IssueState.CLOSED
30-
new_state = state.transition(IssueTransitionType.REOPEN)
31-
assert new_state == IssueState.OPEN
27+
# Test reopening a closed issue
28+
state = IssueState.CLOSED
29+
new_state = state.transition(IssueTransitionType.REOPEN)
30+
assert new_state == IssueState.OPEN
3231

33-
def test_invalid_transitions(self):
34-
"""Test that invalid transitions raise appropriate errors"""
35-
# Test cannot close an already closed issue
36-
with pytest.raises(ValueError) as exc_info:
37-
IssueState.CLOSED.transition(IssueTransitionType.CLOSE_AS_COMPLETE)
38-
assert "Cannot perform CLOSE_AS_COMPLETE transition from state CLOSED" in str(
39-
exc_info.value
40-
)
32+
def test_invalid_transitions():
33+
"""Test that invalid transitions raise appropriate errors"""
34+
# Test cannot close an already closed issue
35+
with pytest.raises(ValueError) as exc_info:
36+
IssueState.CLOSED.transition(IssueTransitionType.CLOSE_AS_COMPLETE)
37+
assert "Cannot perform CLOSE_AS_COMPLETE transition from state CLOSED" in str(
38+
exc_info.value
39+
)
4140

42-
# Test cannot reopen an already open issue
43-
with pytest.raises(ValueError) as exc_info:
44-
IssueState.OPEN.transition(IssueTransitionType.REOPEN)
45-
assert "Cannot perform REOPEN transition from state OPEN" in str(exc_info.value)
41+
# Test cannot reopen an already open issue
42+
with pytest.raises(ValueError) as exc_info:
43+
IssueState.OPEN.transition(IssueTransitionType.REOPEN)
44+
assert "Cannot perform REOPEN transition from state OPEN" in str(exc_info.value)
4645

47-
def test_unknown_transition_type(self):
48-
"""Test that using an undefined transition type raises an error"""
46+
def test_unknown_transition_type():
47+
"""Test that using an undefined transition type raises an error"""
4948

50-
# Create a new enum value that isn't in the transitions dictionary
51-
class FakeTransitionType:
52-
def __str__(self):
53-
return "FAKE_TRANSITION"
49+
# Create a new enum value that isn't in the transitions dictionary
50+
class FakeTransitionType:
51+
def __str__(self):
52+
return "FAKE_TRANSITION"
5453

55-
with pytest.raises(ValueError) as exc_info:
56-
IssueState.OPEN.transition(FakeTransitionType())
57-
assert "Unknown transition type: FAKE_TRANSITION" in str(exc_info.value)
54+
with pytest.raises(ValueError) as exc_info:
55+
IssueState.OPEN.transition(FakeTransitionType())
56+
assert "Unknown transition type: FAKE_TRANSITION" in str(exc_info.value)
5857

59-
def test_transitions_immutability(self):
60-
"""Test that the transitions dictionary cannot be modified at runtime"""
61-
transitions = IssueState.transitions()
62-
with pytest.raises((TypeError, AttributeError)):
63-
transitions[IssueTransitionType.REOPEN] = {
64-
"from": IssueState.OPEN.value,
65-
"to": IssueState.CLOSED.value,
66-
}
58+
def test_transitions_immutability():
59+
"""Test that the transitions dictionary cannot be modified at runtime"""
60+
transitions = IssueState.transitions()
61+
with pytest.raises((TypeError, AttributeError)):
62+
transitions[IssueTransitionType.REOPEN] = {
63+
"from": IssueState.OPEN.value,
64+
"to": IssueState.CLOSED.value,
65+
}

tests/test_issues.py

Lines changed: 79 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from fastapi.testclient import TestClient
21
import pytest
3-
from sqlmodel import SQLModel, Session, create_engine
2+
from fastapi.testclient import TestClient
3+
from sqlmodel import Session, SQLModel, create_engine
44
from sqlmodel.pool import StaticPool
55

66
from main import app
7+
from src.app.usecases.analyze_issue import AnalyzeIssue
78
from src.domain.issue import Issue, IssueState
9+
from src.interface_adapters.exceptions import NotFoundException
810
from src.resource_adapters.persistence.sqlmodel.database import get_db
911
from src.resource_adapters.persistence.sqlmodel.issues import SQLModelIssueRepository
10-
from src.app.usecases.analyze_issue import AnalyzeIssue
11-
from src.interface_adapters.exceptions import NotFoundException
1212

1313

1414
# https://sqlmodel.tiangolo.com/tutorial/fastapi/tests/
@@ -24,71 +24,78 @@ def session_fixture():
2424

2525
@pytest.fixture(name="client")
2626
def client_fixture(session: Session):
27-
def get_session_override():
28-
return session
29-
30-
app.dependency_overrides[get_db] = get_session_override
31-
app.state.running = True
32-
client = TestClient(app)
33-
yield client
34-
app.dependency_overrides.clear()
35-
app.state.running = False
36-
37-
38-
class TestAnalyzeIssue:
39-
def test_analyze_issue_command(self, client: TestClient, session: Session):
40-
# Test case 1: Successful analysis
41-
issue_number = 1
42-
test_issue = Issue(issue_number=issue_number, issue_state=IssueState.OPEN)
43-
44-
repository = SQLModelIssueRepository(session)
45-
46-
repository.add(test_issue)
47-
repository.commit()
48-
retrieved_issue = repository.get_by_id(issue_number)
49-
assert retrieved_issue.issue_number == issue_number
50-
51-
use_case = AnalyzeIssue(issue_number=issue_number, repo=repository)
52-
response = use_case.analyze()
53-
assert response.issue_number == issue_number
54-
55-
def test_analyze_issue_client(self, client: TestClient, session: Session):
56-
# Test case 1: Successful analysis
57-
response = client.post("/issues/1/analyze")
58-
#assert response.status_code == 200
59-
assert response.json() == {"version": 1, "issue_number": 1}
60-
61-
def test_analyze_issue_not_found(self, client: TestClient, session: Session):
62-
# Test case 1: Successful analysis
63-
issue_number = 1
64-
65-
repository = SQLModelIssueRepository(session)
66-
retrieved_issue = repository.get_by_id(issue_number)
67-
assert retrieved_issue.issue_number == 0
68-
69-
use_case = AnalyzeIssue(issue_number=issue_number, repo=repository)
70-
with pytest.raises(NotFoundException) as exc_info:
71-
use_case.analyze()
72-
assert exc_info.value.message == "Issue not found"
73-
74-
response = client.post("/issues/1/analyze")
75-
assert response.status_code == 404
76-
77-
def test_analyze_issue_invalid_number(self, client: TestClient, session: Session):
78-
"""Test analyzing an issue with an invalid issue number."""
79-
80-
response = client.post("/issues/abc/analyze")
81-
assert response.status_code == 422
82-
# Validate error response structure
83-
error_detail = response.json()["detail"]
84-
assert isinstance(error_detail, list)
85-
assert error_detail[0]["type"] == "int_parsing"
86-
assert error_detail[0]["loc"] == ["path", "issue_number"]
87-
88-
def test_analyze_issue_unauthorized(self, client: TestClient):
89-
# Test case 3: Unauthorized access
90-
response = client.post("/issues/456/analyze")
91-
assert response.status_code == 401
92-
assert response.js_on() == {"detail": "Unauthorized"}
93-
94-
# Add more test cases as needed
27+
# def get_session_override():
28+
# return session
29+
30+
# app.dependency_overrides[get_db] = get_session_override
31+
with TestClient(app, raise_server_exceptions=False) as client:
32+
yield client
33+
#app.dependency_overrides.clear()
34+
35+
36+
def test_analyze_issue_command(client: TestClient, session: Session):
37+
# Test case 1: Successful analysis
38+
issue_number = 1
39+
test_issue = Issue(issue_number=issue_number, issue_state=IssueState.OPEN)
40+
41+
repository = SQLModelIssueRepository(session)
42+
43+
repository.add(test_issue)
44+
repository.commit()
45+
retrieved_issue = repository.get_by_id(issue_number)
46+
assert retrieved_issue.issue_number == issue_number
47+
48+
use_case = AnalyzeIssue(issue_number=issue_number, repo=repository)
49+
response = use_case.analyze()
50+
assert response.issue_number == issue_number
51+
52+
53+
def test_analyze_issue_client(client: TestClient, session: Session):
54+
# Test case 1: Successful analysis
55+
issue_number = 1
56+
test_issue = Issue(issue_number=issue_number, issue_state=IssueState.OPEN)
57+
58+
repository = SQLModelIssueRepository(session)
59+
60+
repository.add(test_issue)
61+
repository.commit()
62+
63+
response = client.post("/issues/1/analyze")
64+
# assert response.status_code == 200
65+
assert response.json() == {"version": 1, "issue_number": 1}
66+
67+
68+
def test_analyze_issue_not_found(client: TestClient, session: Session):
69+
# Test case 1: Successful analysis
70+
issue_number = 1
71+
72+
repository = SQLModelIssueRepository(session)
73+
retrieved_issue = repository.get_by_id(issue_number)
74+
assert retrieved_issue.issue_number == 0
75+
76+
use_case = AnalyzeIssue(issue_number=issue_number, repo=repository)
77+
with pytest.raises(NotFoundException) as exc_info:
78+
use_case.analyze()
79+
assert exc_info.value.message == "Issue not found"
80+
81+
response = client.post("/issues/1/analyze")
82+
assert response.status_code == 404
83+
84+
85+
def test_analyze_issue_invalid_number(client: TestClient, session: Session):
86+
"""Test analyzing an issue with an invalid issue number."""
87+
88+
response = client.post("/issues/abc/analyze")
89+
assert response.status_code == 422
90+
# Validate error response structure
91+
error_detail = response.json()["detail"]
92+
assert isinstance(error_detail, list)
93+
assert error_detail[0]["type"] == "int_parsing"
94+
assert error_detail[0]["loc"] == ["path", "issue_number"]
95+
96+
97+
def test_analyze_issue_unauthorized(client: TestClient):
98+
# Test case 3: Unauthorized access
99+
response = client.post("/issues/456/analyze")
100+
assert response.status_code == 401
101+
assert response.js_on() == {"detail": "Unauthorized"}

0 commit comments

Comments
 (0)