1
1
import asyncio
2
2
import logging
3
3
import operator
4
- from typing import Any , Iterator , Type , Union
4
+ import sys
5
+ from typing import Any , Callable , Coroutine , Iterator , Type , TypeVar , Union
5
6
6
7
import sqlalchemy as sa
7
8
from aiohttp import web
13
14
AbstractAdminResource , CreateParams , DeleteManyParams , DeleteParams , GetListParams ,
14
15
GetManyParams , GetOneParams , Record , UpdateManyParams , UpdateParams )
15
16
17
+ if sys .version_info >= (3 , 10 ):
18
+ from typing import ParamSpec
19
+ else :
20
+ from typing_extensions import ParamSpec
21
+
22
+ _P = ParamSpec ("_P" )
23
+ _T = TypeVar ("_T" )
24
+
16
25
logger = logging .getLogger (__name__ )
17
26
18
27
FIELD_TYPES = {
26
35
}
27
36
28
37
38
+ def handle_errors (
39
+ f : Callable [_P , Coroutine [None , None , _T ]]
40
+ ) -> Callable [_P , Coroutine [None , None , _T ]]:
41
+ async def inner (* args : _P .args , ** kwargs : _P .kwargs ) -> _T :
42
+ try :
43
+ return await f (* args , ** kwargs )
44
+ except sa .exc .IntegrityError as e :
45
+ raise web .HTTPBadRequest (reason = e .args [0 ])
46
+ except sa .exc .NoResultFound :
47
+ logger .warning ("No result found (%s)" , args , exc_info = True )
48
+ raise web .HTTPNotFound ()
49
+ except sa .exc .CompileError as e :
50
+ logger .warning ("CompileError (%s)" , args , exc_info = True )
51
+ raise web .HTTPBadRequest (reason = str (e ))
52
+ return inner
53
+
54
+
29
55
def create_filters (columns : sa .ColumnCollection [str , sa .Column [object ]],
30
56
filters : dict [str , object ]) -> Iterator [ExpressionElementRole [Any ]]:
31
57
return (columns [k ].in_ (v ) if isinstance (v , list )
@@ -95,6 +121,7 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, Type[Declara
95
121
96
122
super ().__init__ ()
97
123
124
+ @handle_errors
98
125
async def get_list (self , params : GetListParams ) -> tuple [list [Record ], int ]:
99
126
per_page = params ["pagination" ]["perPage" ]
100
127
offset = (params ["pagination" ]["page" ] - 1 ) * per_page
@@ -115,16 +142,14 @@ async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:
115
142
116
143
return entities , count
117
144
145
+ @handle_errors
118
146
async def get_one (self , params : GetOneParams ) -> Record :
119
147
async with self ._db .connect () as conn :
120
148
stmt = sa .select (self ._table ).where (self ._table .c [self .primary_key ] == params ["id" ])
121
149
result = await conn .execute (stmt )
122
- try :
123
- return result .one ()._asdict ()
124
- except sa .exc .NoResultFound :
125
- logger .warning ("No result found (%s)" , params ["id" ], exc_info = True )
126
- raise web .HTTPNotFound ()
150
+ return result .one ()._asdict ()
127
151
152
+ @handle_errors
128
153
async def get_many (self , params : GetManyParams ) -> list [Record ]:
129
154
async with self ._db .connect () as conn :
130
155
stmt = sa .select (self ._table ).where (self ._table .c [self .primary_key ].in_ (params ["ids" ]))
@@ -134,6 +159,7 @@ async def get_many(self, params: GetManyParams) -> list[Record]:
134
159
return records
135
160
raise web .HTTPNotFound ()
136
161
162
+ @handle_errors
137
163
async def create (self , params : CreateParams ) -> Record :
138
164
async with self ._db .begin () as conn :
139
165
stmt = sa .insert (self ._table ).values (params ["data" ]).returning (* self ._table .c )
@@ -144,44 +170,32 @@ async def create(self, params: CreateParams) -> Record:
144
170
raise web .HTTPBadRequest (reason = "Integrity error (element already exists?)" )
145
171
return row .one ()._asdict ()
146
172
173
+ @handle_errors
147
174
async def update (self , params : UpdateParams ) -> Record :
148
175
async with self ._db .begin () as conn :
149
176
stmt = sa .update (self ._table ).where (self ._table .c [self .primary_key ] == params ["id" ])
150
177
stmt = stmt .values (params ["data" ]).returning (* self ._table .c )
151
- try :
152
- row = await conn .execute (stmt )
153
- except sa .exc .CompileError as e :
154
- logger .warning ("CompileError (%s)" , params ["id" ], exc_info = True )
155
- raise web .HTTPBadRequest (reason = str (e ))
156
- try :
157
- return row .one ()._asdict ()
158
- except sa .exc .NoResultFound :
159
- logger .warning ("No result found (%s)" , params ["id" ], exc_info = True )
160
- raise web .HTTPNotFound ()
178
+ row = await conn .execute (stmt )
179
+ return row .one ()._asdict ()
161
180
181
+ @handle_errors
162
182
async def update_many (self , params : UpdateManyParams ) -> list [Union [str , int ]]:
163
183
async with self ._db .begin () as conn :
164
184
stmt = sa .update (self ._table ).where (self ._table .c [self .primary_key ].in_ (params ["ids" ]))
165
185
stmt = stmt .values (params ["data" ]).returning (self ._table .c [self .primary_key ])
166
- try :
167
- r = await conn .scalars (stmt )
168
- except sa .exc .CompileError as e :
169
- logger .warning ("CompileError (%s)" , params ["ids" ], exc_info = True )
170
- raise web .HTTPBadRequest (reason = str (e ))
186
+ r = await conn .scalars (stmt )
171
187
# The security check has already called get_many(), so we can be sure
172
188
# there will be results here.
173
189
return list (r )
174
190
191
+ @handle_errors
175
192
async def delete (self , params : DeleteParams ) -> Record :
176
193
async with self ._db .begin () as conn :
177
194
stmt = sa .delete (self ._table ).where (self ._table .c [self .primary_key ] == params ["id" ])
178
195
row = await conn .execute (stmt .returning (* self ._table .c ))
179
- try :
180
- return row .one ()._asdict ()
181
- except sa .exc .NoResultFound :
182
- logger .warning ("No result found (%s)" , params ["id" ], exc_info = True )
183
- raise web .HTTPNotFound ()
196
+ return row .one ()._asdict ()
184
197
198
+ @handle_errors
185
199
async def delete_many (self , params : DeleteManyParams ) -> list [Union [str , int ]]:
186
200
async with self ._db .begin () as conn :
187
201
stmt = sa .delete (self ._table ).where (self ._table .c [self .primary_key ].in_ (params ["ids" ]))
0 commit comments