11import asyncio
22import logging
33import operator
4- from typing import Any , Iterator , Type , Union
4+ import sys
5+ from typing import Any , Callable , Coroutine , Iterator , Type , TypeVar , Union
56
67import sqlalchemy as sa
78from aiohttp import web
1314 AbstractAdminResource , CreateParams , DeleteManyParams , DeleteParams , GetListParams ,
1415 GetManyParams , GetOneParams , Record , UpdateManyParams , UpdateParams )
1516
17+ if sys .version_info >= (3 , 10 ):
18+ from typing import ParamSpec
19+ else :
20+ from typing_extensions import ParamSpec
21+
22+ _P = ParamSpec ("_P" )
23+ _T = TypeVar ("_T" )
24+
1625logger = logging .getLogger (__name__ )
1726
1827FIELD_TYPES = {
2635}
2736
2837
38+ def handle_errors (
39+ f : Callable [_P , Coroutine [None , None , _T ]]
40+ ) -> Callable [_P , Coroutine [None , None , _T ]]:
41+ async def inner (* args : _P .args , ** kwargs : _P .kwargs ) -> _T :
42+ try :
43+ return await f (* args , ** kwargs )
44+ except sa .exc .IntegrityError as e :
45+ raise web .HTTPBadRequest (reason = e .args [0 ])
46+ except sa .exc .NoResultFound :
47+ logger .warning ("No result found (%s)" , args , exc_info = True )
48+ raise web .HTTPNotFound ()
49+ except sa .exc .CompileError as e :
50+ logger .warning ("CompileError (%s)" , args , exc_info = True )
51+ raise web .HTTPBadRequest (reason = str (e ))
52+ return inner
53+
54+
2955def create_filters (columns : sa .ColumnCollection [str , sa .Column [object ]],
3056 filters : dict [str , object ]) -> Iterator [ExpressionElementRole [Any ]]:
3157 return (columns [k ].in_ (v ) if isinstance (v , list )
@@ -95,6 +121,7 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, Type[Declara
95121
96122 super ().__init__ ()
97123
124+ @handle_errors
98125 async def get_list (self , params : GetListParams ) -> tuple [list [Record ], int ]:
99126 per_page = params ["pagination" ]["perPage" ]
100127 offset = (params ["pagination" ]["page" ] - 1 ) * per_page
@@ -115,16 +142,14 @@ async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:
115142
116143 return entities , count
117144
145+ @handle_errors
118146 async def get_one (self , params : GetOneParams ) -> Record :
119147 async with self ._db .connect () as conn :
120148 stmt = sa .select (self ._table ).where (self ._table .c [self .primary_key ] == params ["id" ])
121149 result = await conn .execute (stmt )
122- try :
123- return result .one ()._asdict ()
124- except sa .exc .NoResultFound :
125- logger .warning ("No result found (%s)" , params ["id" ], exc_info = True )
126- raise web .HTTPNotFound ()
150+ return result .one ()._asdict ()
127151
152+ @handle_errors
128153 async def get_many (self , params : GetManyParams ) -> list [Record ]:
129154 async with self ._db .connect () as conn :
130155 stmt = sa .select (self ._table ).where (self ._table .c [self .primary_key ].in_ (params ["ids" ]))
@@ -134,6 +159,7 @@ async def get_many(self, params: GetManyParams) -> list[Record]:
134159 return records
135160 raise web .HTTPNotFound ()
136161
162+ @handle_errors
137163 async def create (self , params : CreateParams ) -> Record :
138164 async with self ._db .begin () as conn :
139165 stmt = sa .insert (self ._table ).values (params ["data" ]).returning (* self ._table .c )
@@ -144,44 +170,32 @@ async def create(self, params: CreateParams) -> Record:
144170 raise web .HTTPBadRequest (reason = "Integrity error (element already exists?)" )
145171 return row .one ()._asdict ()
146172
173+ @handle_errors
147174 async def update (self , params : UpdateParams ) -> Record :
148175 async with self ._db .begin () as conn :
149176 stmt = sa .update (self ._table ).where (self ._table .c [self .primary_key ] == params ["id" ])
150177 stmt = stmt .values (params ["data" ]).returning (* self ._table .c )
151- try :
152- row = await conn .execute (stmt )
153- except sa .exc .CompileError as e :
154- logger .warning ("CompileError (%s)" , params ["id" ], exc_info = True )
155- raise web .HTTPBadRequest (reason = str (e ))
156- try :
157- return row .one ()._asdict ()
158- except sa .exc .NoResultFound :
159- logger .warning ("No result found (%s)" , params ["id" ], exc_info = True )
160- raise web .HTTPNotFound ()
178+ row = await conn .execute (stmt )
179+ return row .one ()._asdict ()
161180
181+ @handle_errors
162182 async def update_many (self , params : UpdateManyParams ) -> list [Union [str , int ]]:
163183 async with self ._db .begin () as conn :
164184 stmt = sa .update (self ._table ).where (self ._table .c [self .primary_key ].in_ (params ["ids" ]))
165185 stmt = stmt .values (params ["data" ]).returning (self ._table .c [self .primary_key ])
166- try :
167- r = await conn .scalars (stmt )
168- except sa .exc .CompileError as e :
169- logger .warning ("CompileError (%s)" , params ["ids" ], exc_info = True )
170- raise web .HTTPBadRequest (reason = str (e ))
186+ r = await conn .scalars (stmt )
171187 # The security check has already called get_many(), so we can be sure
172188 # there will be results here.
173189 return list (r )
174190
191+ @handle_errors
175192 async def delete (self , params : DeleteParams ) -> Record :
176193 async with self ._db .begin () as conn :
177194 stmt = sa .delete (self ._table ).where (self ._table .c [self .primary_key ] == params ["id" ])
178195 row = await conn .execute (stmt .returning (* self ._table .c ))
179- try :
180- return row .one ()._asdict ()
181- except sa .exc .NoResultFound :
182- logger .warning ("No result found (%s)" , params ["id" ], exc_info = True )
183- raise web .HTTPNotFound ()
196+ return row .one ()._asdict ()
184197
198+ @handle_errors
185199 async def delete_many (self , params : DeleteManyParams ) -> list [Union [str , int ]]:
186200 async with self ._db .begin () as conn :
187201 stmt = sa .delete (self ._table ).where (self ._table .c [self .primary_key ].in_ (params ["ids" ]))
0 commit comments