Skip to content

Commit ca9e7ae

Browse files
Improve SA error handling (#689)
1 parent 52c8ae4 commit ca9e7ae

File tree

3 files changed

+43
-27
lines changed

3 files changed

+43
-27
lines changed

aiohttp_admin/backends/sqlalchemy.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import logging
33
import operator
4-
from typing import Any, Iterator, Type, Union
4+
import sys
5+
from typing import Any, Callable, Coroutine, Iterator, Type, TypeVar, Union
56

67
import sqlalchemy as sa
78
from aiohttp import web
@@ -13,6 +14,14 @@
1314
AbstractAdminResource, CreateParams, DeleteManyParams, DeleteParams, GetListParams,
1415
GetManyParams, GetOneParams, Record, UpdateManyParams, UpdateParams)
1516

17+
if sys.version_info >= (3, 10):
18+
from typing import ParamSpec
19+
else:
20+
from typing_extensions import ParamSpec
21+
22+
_P = ParamSpec("_P")
23+
_T = TypeVar("_T")
24+
1625
logger = logging.getLogger(__name__)
1726

1827
FIELD_TYPES = {
@@ -26,6 +35,23 @@
2635
}
2736

2837

38+
def handle_errors(
39+
f: Callable[_P, Coroutine[None, None, _T]]
40+
) -> Callable[_P, Coroutine[None, None, _T]]:
41+
async def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
42+
try:
43+
return await f(*args, **kwargs)
44+
except sa.exc.IntegrityError as e:
45+
raise web.HTTPBadRequest(reason=e.args[0])
46+
except sa.exc.NoResultFound:
47+
logger.warning("No result found (%s)", args, exc_info=True)
48+
raise web.HTTPNotFound()
49+
except sa.exc.CompileError as e:
50+
logger.warning("CompileError (%s)", args, exc_info=True)
51+
raise web.HTTPBadRequest(reason=str(e))
52+
return inner
53+
54+
2955
def create_filters(columns: sa.ColumnCollection[str, sa.Column[object]],
3056
filters: dict[str, object]) -> Iterator[ExpressionElementRole[Any]]:
3157
return (columns[k].in_(v) if isinstance(v, list)
@@ -95,6 +121,7 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, Type[Declara
95121

96122
super().__init__()
97123

124+
@handle_errors
98125
async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:
99126
per_page = params["pagination"]["perPage"]
100127
offset = (params["pagination"]["page"] - 1) * per_page
@@ -115,16 +142,14 @@ async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:
115142

116143
return entities, count
117144

145+
@handle_errors
118146
async def get_one(self, params: GetOneParams) -> Record:
119147
async with self._db.connect() as conn:
120148
stmt = sa.select(self._table).where(self._table.c[self.primary_key] == params["id"])
121149
result = await conn.execute(stmt)
122-
try:
123-
return result.one()._asdict()
124-
except sa.exc.NoResultFound:
125-
logger.warning("No result found (%s)", params["id"], exc_info=True)
126-
raise web.HTTPNotFound()
150+
return result.one()._asdict()
127151

152+
@handle_errors
128153
async def get_many(self, params: GetManyParams) -> list[Record]:
129154
async with self._db.connect() as conn:
130155
stmt = sa.select(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
@@ -134,6 +159,7 @@ async def get_many(self, params: GetManyParams) -> list[Record]:
134159
return records
135160
raise web.HTTPNotFound()
136161

162+
@handle_errors
137163
async def create(self, params: CreateParams) -> Record:
138164
async with self._db.begin() as conn:
139165
stmt = sa.insert(self._table).values(params["data"]).returning(*self._table.c)
@@ -144,44 +170,32 @@ async def create(self, params: CreateParams) -> Record:
144170
raise web.HTTPBadRequest(reason="Integrity error (element already exists?)")
145171
return row.one()._asdict()
146172

173+
@handle_errors
147174
async def update(self, params: UpdateParams) -> Record:
148175
async with self._db.begin() as conn:
149176
stmt = sa.update(self._table).where(self._table.c[self.primary_key] == params["id"])
150177
stmt = stmt.values(params["data"]).returning(*self._table.c)
151-
try:
152-
row = await conn.execute(stmt)
153-
except sa.exc.CompileError as e:
154-
logger.warning("CompileError (%s)", params["id"], exc_info=True)
155-
raise web.HTTPBadRequest(reason=str(e))
156-
try:
157-
return row.one()._asdict()
158-
except sa.exc.NoResultFound:
159-
logger.warning("No result found (%s)", params["id"], exc_info=True)
160-
raise web.HTTPNotFound()
178+
row = await conn.execute(stmt)
179+
return row.one()._asdict()
161180

181+
@handle_errors
162182
async def update_many(self, params: UpdateManyParams) -> list[Union[str, int]]:
163183
async with self._db.begin() as conn:
164184
stmt = sa.update(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
165185
stmt = stmt.values(params["data"]).returning(self._table.c[self.primary_key])
166-
try:
167-
r = await conn.scalars(stmt)
168-
except sa.exc.CompileError as e:
169-
logger.warning("CompileError (%s)", params["ids"], exc_info=True)
170-
raise web.HTTPBadRequest(reason=str(e))
186+
r = await conn.scalars(stmt)
171187
# The security check has already called get_many(), so we can be sure
172188
# there will be results here.
173189
return list(r)
174190

191+
@handle_errors
175192
async def delete(self, params: DeleteParams) -> Record:
176193
async with self._db.begin() as conn:
177194
stmt = sa.delete(self._table).where(self._table.c[self.primary_key] == params["id"])
178195
row = await conn.execute(stmt.returning(*self._table.c))
179-
try:
180-
return row.one()._asdict()
181-
except sa.exc.NoResultFound:
182-
logger.warning("No result found (%s)", params["id"], exc_info=True)
183-
raise web.HTTPNotFound()
196+
return row.one()._asdict()
184197

198+
@handle_errors
185199
async def delete_many(self, params: DeleteManyParams) -> list[Union[str, int]]:
186200
async with self._db.begin() as conn:
187201
stmt = sa.delete(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ pytest==7.3.1
99
pytest-aiohttp==1.0.4
1010
pytest-cov==4.0.0
1111
sqlalchemy==2.0.9
12+
typing_extensions>=3.10; python_version<"3.10"

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def read_version():
4141
license="Apache 2",
4242
packages=find_packages(),
4343
install_requires=("aiohttp>=3.8.2", "aiohttp_security", "aiohttp_session",
44-
"cryptography", "pydantic"),
44+
"cryptography", "pydantic",
45+
'typing_extensions>=3.10; python_version<"3.10"'),
4546
extras_require={"sa": ["sqlalchemy>=2.0.4,<3"]},
4647
include_package_data=True)

0 commit comments

Comments
 (0)