2
2
import json
3
3
import warnings
4
4
from abc import ABC , abstractmethod
5
- from datetime import datetime
5
+ from datetime import date , datetime
6
6
from enum import Enum
7
7
from functools import cached_property , partial
8
+ from types import MappingProxyType
8
9
from typing import Any , Literal , Optional , TypedDict , Union
9
10
10
11
from aiohttp import web
16
17
17
18
Record = dict [str , object ]
18
19
20
+ INPUT_TYPES = MappingProxyType ({
21
+ "BooleanInput" : bool ,
22
+ "DateInput" : date ,
23
+ "DateTimeInput" : datetime ,
24
+ "NumberInput" : float
25
+ })
26
+
19
27
20
28
class Encoder (json .JSONEncoder ):
21
29
def default (self , o : object ) -> Any :
22
- if isinstance (o , datetime ):
30
+ if isinstance (o , date ):
23
31
return str (o )
24
32
if isinstance (o , Enum ):
25
33
return o .value
@@ -92,6 +100,10 @@ def __init__(self) -> None:
92
100
if "id" in self .fields and self .primary_key != "id" :
93
101
warnings .warn ("A non-PK 'id' column is likely to break the admin." , stacklevel = 2 )
94
102
103
+ d = {k : INPUT_TYPES .get (v ["type" ], str ) for k , v in self .inputs .items ()}
104
+ # For runtime type checking only.
105
+ self ._record_type = TypedDict ("RecordType" , d , total = False ) # type: ignore[misc]
106
+
95
107
async def filter_by_permissions (self , request : web .Request , perm_type : str ,
96
108
record : Record , original : Optional [Record ] = None ) -> Record :
97
109
"""Return a filtered record containing permissible fields only."""
@@ -182,6 +194,11 @@ async def _get_many(self, request: web.Request) -> web.Response:
182
194
183
195
async def _create (self , request : web .Request ) -> web .Response :
184
196
query = parse_obj_as (CreateParams , request .query )
197
+ # TODO(Pydantic): Dissallow extra arguments
198
+ for k in query ["data" ]:
199
+ if k not in self .inputs and k != "id" :
200
+ raise web .HTTPBadRequest (reason = f"Invalid field '{ k } '" )
201
+ query ["data" ] = parse_obj_as (self ._record_type , query ["data" ])
185
202
await check_permission (request , f"admin.{ self .name } .add" , context = (request , query ["data" ]))
186
203
for k , v in query ["data" ].items ():
187
204
if v is not None :
@@ -196,6 +213,12 @@ async def _create(self, request: web.Request) -> web.Response:
196
213
async def _update (self , request : web .Request ) -> web .Response :
197
214
await check_permission (request , f"admin.{ self .name } .edit" , context = (request , None ))
198
215
query = parse_obj_as (UpdateParams , request .query )
216
+ # TODO(Pydantic): Dissallow extra arguments
217
+ for k in query ["data" ]:
218
+ if k not in self .inputs and k != "id" :
219
+ raise web .HTTPBadRequest (reason = f"Invalid field '{ k } '" )
220
+ query ["data" ] = parse_obj_as (self ._record_type , query ["data" ])
221
+ query ["previousData" ] = parse_obj_as (self ._record_type , query ["previousData" ])
199
222
200
223
if self .primary_key != "id" :
201
224
query ["data" ].pop ("id" , None )
@@ -224,6 +247,11 @@ async def _update(self, request: web.Request) -> web.Response:
224
247
async def _update_many (self , request : web .Request ) -> web .Response :
225
248
await check_permission (request , f"admin.{ self .name } .edit" , context = (request , None ))
226
249
query = parse_obj_as (UpdateManyParams , request .query )
250
+ # TODO(Pydantic): Dissallow extra arguments
251
+ for k in query ["data" ]:
252
+ if k not in self .inputs and k != "id" :
253
+ raise web .HTTPBadRequest (reason = f"Invalid field '{ k } '" )
254
+ query ["data" ] = parse_obj_as (self ._record_type , query ["data" ])
227
255
228
256
# Check original records are allowed by permission filters.
229
257
originals = await self .get_many ({"ids" : query ["ids" ]})
@@ -243,6 +271,7 @@ async def _update_many(self, request: web.Request) -> web.Response:
243
271
async def _delete (self , request : web .Request ) -> web .Response :
244
272
await check_permission (request , f"admin.{ self .name } .delete" , context = (request , None ))
245
273
query = parse_obj_as (DeleteParams , request .query )
274
+ query ["previousData" ] = parse_obj_as (self ._record_type , query ["previousData" ])
246
275
247
276
original = await self .get_one ({"id" : query ["id" ]})
248
277
if not await permits (request , f"admin.{ self .name } .delete" , context = (request , original )):
0 commit comments