Skip to content

Commit c08ef25

Browse files
authored
fix: error message handling and isolation in repositories (#605)
Correct error message retrieval and ensure that error message overrides are isolated for different repository instances. This improves the clarity and reliability of error messages across the application.
1 parent 3dc01e2 commit c08ef25

File tree

7 files changed

+183
-21
lines changed

7 files changed

+183
-21
lines changed

advanced_alchemy/repository/_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def _get_error_messages(
586586
error_messages = None
587587
if default_messages == Empty:
588588
default_messages = None
589-
messages = DEFAULT_ERROR_MESSAGE_TEMPLATES
589+
messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
590590
if default_messages and isinstance(default_messages, dict):
591591
messages.update(default_messages)
592592
if error_messages:

advanced_alchemy/repository/_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def _get_error_messages(
587587
error_messages = None
588588
if default_messages == Empty:
589589
default_messages = None
590-
messages = DEFAULT_ERROR_MESSAGE_TEMPLATES
590+
messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
591591
if default_messages and isinstance(default_messages, dict):
592592
messages.update(default_messages)
593593
if error_messages:

advanced_alchemy/repository/memory/_async.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def __init__(
9999
self.auto_expunge = auto_expunge
100100
self.auto_refresh = auto_refresh
101101
self.auto_commit = auto_commit
102-
self.error_messages = self._get_error_messages(error_messages=error_messages)
102+
self.error_messages = self._get_error_messages(
103+
error_messages=error_messages, default_messages=self.error_messages
104+
)
103105
self.wrap_exceptions = wrap_exceptions
104106
self.order_by = order_by
105107
self._dialect: Dialect = create_autospec(Dialect, instance=True)
@@ -121,13 +123,14 @@ def _get_error_messages(
121123
) -> Optional[ErrorMessages]:
122124
if error_messages == Empty:
123125
error_messages = None
124-
default_messages = cast(
125-
"Optional[ErrorMessages]",
126-
default_messages if default_messages != Empty else DEFAULT_ERROR_MESSAGE_TEMPLATES,
127-
)
128-
if error_messages is not None and default_messages is not None:
129-
default_messages.update(cast("ErrorMessages", error_messages))
130-
return default_messages
126+
if default_messages == Empty:
127+
default_messages = None
128+
messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
129+
if default_messages:
130+
messages.update(cast("ErrorMessages", default_messages))
131+
if error_messages:
132+
messages.update(cast("ErrorMessages", error_messages))
133+
return messages
131134

132135
@classmethod
133136
def __database_add__(cls, identity: Any, data: ModelT) -> ModelT:

advanced_alchemy/repository/memory/_sync.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def __init__(
100100
self.auto_expunge = auto_expunge
101101
self.auto_refresh = auto_refresh
102102
self.auto_commit = auto_commit
103-
self.error_messages = self._get_error_messages(error_messages=error_messages)
103+
self.error_messages = self._get_error_messages(
104+
error_messages=error_messages, default_messages=self.error_messages
105+
)
104106
self.wrap_exceptions = wrap_exceptions
105107
self.order_by = order_by
106108
self._dialect: Dialect = create_autospec(Dialect, instance=True)
@@ -122,13 +124,14 @@ def _get_error_messages(
122124
) -> Optional[ErrorMessages]:
123125
if error_messages == Empty:
124126
error_messages = None
125-
default_messages = cast(
126-
"Optional[ErrorMessages]",
127-
default_messages if default_messages != Empty else DEFAULT_ERROR_MESSAGE_TEMPLATES,
128-
)
129-
if error_messages is not None and default_messages is not None:
130-
default_messages.update(cast("ErrorMessages", error_messages))
131-
return default_messages
127+
if default_messages == Empty:
128+
default_messages = None
129+
messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
130+
if default_messages:
131+
messages.update(cast("ErrorMessages", default_messages))
132+
if error_messages:
133+
messages.update(cast("ErrorMessages", error_messages))
134+
return messages
132135

133136
@classmethod
134137
def __database_add__(cls, identity: Any, data: ModelT) -> ModelT:

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ pythonVersion = "3.9"
522522
reportMissingTypeStubs = false
523523
reportPrivateImportUsage = false
524524
reportUnknownMemberType = false
525-
reportUnnecessaryTypeIgnoreComments = true
526525
reportUnusedFunction = false
527526
strict = ["advanced_alchemy/**/*"]
528527
venv = ".venv"

tests/integration/test_repository.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import asyncio
44
import datetime
55
from collections.abc import Generator
6-
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
6+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
7+
from unittest.mock import create_autospec
78
from uuid import UUID
89

910
import pytest
10-
from sqlalchemy.ext.asyncio import AsyncSession
11+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
1112
from sqlalchemy.orm import Session
1213
from time_machine import travel
1314

@@ -18,6 +19,7 @@
1819
SearchFilter,
1920
)
2021
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
22+
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES
2123
from advanced_alchemy.repository.memory import (
2224
SQLAlchemyAsyncMockRepository,
2325
SQLAlchemySyncMockRepository,
@@ -635,6 +637,64 @@ async def test_repo_error_messages(seeded_test_session_async: "tuple[AsyncSessio
635637
await maybe_async(author_repo.get(non_existent_id))
636638

637639

640+
def _make_mock_session(engine: AsyncEngine) -> AsyncSession:
641+
session = cast(AsyncSession, create_autospec(AsyncSession, instance=True))
642+
session.bind = engine
643+
session.get_bind.return_value = engine
644+
return session
645+
646+
647+
@pytest.mark.mock_async
648+
def test_repo_error_message_overrides_are_isolated(
649+
mock_async_engine: AsyncEngine, uuid_models_dba: "dict[str, type]"
650+
) -> None:
651+
author_model = cast(type[Any], uuid_models_dba["author"])
652+
default_not_found = DEFAULT_ERROR_MESSAGE_TEMPLATES.get("not_found")
653+
654+
class BaseRepo(SQLAlchemyAsyncRepository[Any]):
655+
model_type = author_model
656+
657+
class RepoA(BaseRepo):
658+
error_messages = {"not_found": "Author A not found"}
659+
660+
class RepoB(BaseRepo):
661+
error_messages = {"not_found": "Author B not found"}
662+
663+
repo_a_first = RepoA(session=_make_mock_session(mock_async_engine))
664+
repo_b = RepoB(session=_make_mock_session(mock_async_engine))
665+
repo_a_second = RepoA(session=_make_mock_session(mock_async_engine))
666+
667+
assert repo_a_first.error_messages is not DEFAULT_ERROR_MESSAGE_TEMPLATES
668+
assert repo_a_first.error_messages is not repo_b.error_messages
669+
assert repo_a_first.error_messages["not_found"] == "Author A not found"
670+
assert repo_b.error_messages["not_found"] == "Author B not found"
671+
assert repo_a_second.error_messages["not_found"] == "Author A not found"
672+
assert DEFAULT_ERROR_MESSAGE_TEMPLATES["not_found"] == default_not_found
673+
674+
675+
@pytest.mark.mock_async
676+
def test_mock_repo_error_message_overrides_are_isolated(
677+
mock_async_engine: AsyncEngine, uuid_models_dba: "dict[str, type]"
678+
) -> None:
679+
author_model = cast(type[Any], uuid_models_dba["author"])
680+
681+
class BaseMockRepo(SQLAlchemyAsyncMockRepository[Any]):
682+
model_type = author_model
683+
684+
class RepoA(BaseMockRepo):
685+
error_messages = {"not_found": "Mock Author A not found"}
686+
687+
class RepoB(BaseMockRepo):
688+
error_messages = {"not_found": "Mock Author B not found"}
689+
690+
repo_a = RepoA(session=_make_mock_session(mock_async_engine))
691+
repo_b = RepoB(session=_make_mock_session(mock_async_engine))
692+
693+
assert repo_a.error_messages is not repo_b.error_messages
694+
assert repo_a.error_messages["not_found"] == "Mock Author A not found"
695+
assert repo_b.error_messages["not_found"] == "Mock Author B not found"
696+
697+
638698
# Comprehensive tests for GitHub issue #535 and bug_fix.md issues
639699
async def test_service_pydantic_partial_update_github_535(
640700
seeded_test_session_async: "tuple[AsyncSession, dict[str, type]]",
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, cast
4+
from unittest.mock import create_autospec
5+
6+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
7+
from sqlalchemy.orm import Session
8+
9+
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES
10+
from advanced_alchemy.repository.memory import (
11+
SQLAlchemyAsyncMockRepository,
12+
SQLAlchemySyncMockRepository,
13+
)
14+
15+
16+
def _make_async_session() -> AsyncSession:
17+
session = cast(AsyncSession, create_autospec(AsyncSession, instance=True))
18+
engine = cast(AsyncEngine, create_autospec(AsyncEngine, instance=True))
19+
engine.dialect.name = "mock"
20+
session.bind = engine
21+
session.get_bind.return_value = engine
22+
return session
23+
24+
25+
def _make_sync_session() -> Session:
26+
session = cast(Session, create_autospec(Session, instance=True))
27+
session.bind = cast(Any, create_autospec(object, instance=True))
28+
return session
29+
30+
31+
def test_async_mock_repository_error_messages_isolated() -> None:
32+
class BaseRepo(SQLAlchemyAsyncMockRepository[Any]):
33+
model_type = object
34+
35+
class RepoA(BaseRepo):
36+
error_messages = {"not_found": "Async Repo A"}
37+
38+
class RepoB(BaseRepo):
39+
error_messages = {"not_found": "Async Repo B"}
40+
41+
repo_a_first = RepoA(session=_make_async_session())
42+
repo_b = RepoB(session=_make_async_session())
43+
repo_a_second = RepoA(session=_make_async_session())
44+
45+
assert repo_a_first.error_messages is not DEFAULT_ERROR_MESSAGE_TEMPLATES
46+
assert repo_a_first.error_messages is not repo_b.error_messages
47+
assert repo_a_first.error_messages["not_found"] == "Async Repo A"
48+
assert repo_b.error_messages["not_found"] == "Async Repo B"
49+
assert repo_a_second.error_messages["not_found"] == "Async Repo A"
50+
assert DEFAULT_ERROR_MESSAGE_TEMPLATES["not_found"] == "The requested resource was not found"
51+
52+
53+
def test_async_mock_repository_instance_override_does_not_mutate_class() -> None:
54+
class Repo(SQLAlchemyAsyncMockRepository[Any]):
55+
model_type = object
56+
error_messages = {"other": "default other"}
57+
58+
repo_custom = Repo(session=_make_async_session(), error_messages={"other": "custom other"})
59+
repo_plain = Repo(session=_make_async_session())
60+
61+
assert repo_custom.error_messages["other"] == "custom other"
62+
assert repo_plain.error_messages["other"] == "default other"
63+
assert Repo.error_messages["other"] == "default other"
64+
65+
66+
def test_sync_mock_repository_error_messages_isolated() -> None:
67+
class BaseRepo(SQLAlchemySyncMockRepository[Any]):
68+
model_type = object
69+
70+
class RepoA(BaseRepo):
71+
error_messages = {"not_found": "Sync Repo A"}
72+
73+
class RepoB(BaseRepo):
74+
error_messages = {"not_found": "Sync Repo B"}
75+
76+
repo_a_first = RepoA(session=_make_sync_session())
77+
repo_b = RepoB(session=_make_sync_session())
78+
repo_a_second = RepoA(session=_make_sync_session())
79+
80+
assert repo_a_first.error_messages is not DEFAULT_ERROR_MESSAGE_TEMPLATES
81+
assert repo_a_first.error_messages is not repo_b.error_messages
82+
assert repo_a_first.error_messages["not_found"] == "Sync Repo A"
83+
assert repo_b.error_messages["not_found"] == "Sync Repo B"
84+
assert repo_a_second.error_messages["not_found"] == "Sync Repo A"
85+
86+
87+
def test_sync_mock_repository_instance_override_does_not_mutate_class() -> None:
88+
class Repo(SQLAlchemySyncMockRepository[Any]):
89+
model_type = object
90+
error_messages = {"duplicate_key": "sync default"}
91+
92+
repo_custom = Repo(session=_make_sync_session(), error_messages={"duplicate_key": "custom sync"})
93+
repo_plain = Repo(session=_make_sync_session())
94+
95+
assert repo_custom.error_messages["duplicate_key"] == "custom sync"
96+
assert repo_plain.error_messages["duplicate_key"] == "sync default"
97+
assert Repo.error_messages["duplicate_key"] == "sync default"

0 commit comments

Comments
 (0)