22import json
33import warnings
44from abc import ABC , abstractmethod
5- from datetime import datetime
5+ from datetime import date , datetime
66from enum import Enum
77from functools import cached_property , partial
8+ from types import MappingProxyType
89from typing import Any , Literal , Optional , TypedDict , Union
910
1011from aiohttp import web
1617
1718Record = dict [str , object ]
1819
20+ INPUT_TYPES = MappingProxyType ({
21+ "BooleanInput" : bool ,
22+ "DateInput" : date ,
23+ "DateTimeInput" : datetime ,
24+ "NumberInput" : float
25+ })
26+
1927
2028class Encoder (json .JSONEncoder ):
2129 def default (self , o : object ) -> Any :
22- if isinstance (o , datetime ):
30+ if isinstance (o , date ):
2331 return str (o )
2432 if isinstance (o , Enum ):
2533 return o .value
@@ -92,6 +100,10 @@ def __init__(self) -> None:
92100 if "id" in self .fields and self .primary_key != "id" :
93101 warnings .warn ("A non-PK 'id' column is likely to break the admin." , stacklevel = 2 )
94102
103+ d = {k : INPUT_TYPES .get (v ["type" ], str ) for k , v in self .inputs .items ()}
104+ # For runtime type checking only.
105+ self ._record_type = TypedDict ("RecordType" , d , total = False ) # type: ignore[misc]
106+
95107 async def filter_by_permissions (self , request : web .Request , perm_type : str ,
96108 record : Record , original : Optional [Record ] = None ) -> Record :
97109 """Return a filtered record containing permissible fields only."""
@@ -182,6 +194,11 @@ async def _get_many(self, request: web.Request) -> web.Response:
182194
183195 async def _create (self , request : web .Request ) -> web .Response :
184196 query = parse_obj_as (CreateParams , request .query )
197+ # TODO(Pydantic): Dissallow extra arguments
198+ for k in query ["data" ]:
199+ if k not in self .inputs and k != "id" :
200+ raise web .HTTPBadRequest (reason = f"Invalid field '{ k } '" )
201+ query ["data" ] = parse_obj_as (self ._record_type , query ["data" ])
185202 await check_permission (request , f"admin.{ self .name } .add" , context = (request , query ["data" ]))
186203 for k , v in query ["data" ].items ():
187204 if v is not None :
@@ -196,6 +213,12 @@ async def _create(self, request: web.Request) -> web.Response:
196213 async def _update (self , request : web .Request ) -> web .Response :
197214 await check_permission (request , f"admin.{ self .name } .edit" , context = (request , None ))
198215 query = parse_obj_as (UpdateParams , request .query )
216+ # TODO(Pydantic): Dissallow extra arguments
217+ for k in query ["data" ]:
218+ if k not in self .inputs and k != "id" :
219+ raise web .HTTPBadRequest (reason = f"Invalid field '{ k } '" )
220+ query ["data" ] = parse_obj_as (self ._record_type , query ["data" ])
221+ query ["previousData" ] = parse_obj_as (self ._record_type , query ["previousData" ])
199222
200223 if self .primary_key != "id" :
201224 query ["data" ].pop ("id" , None )
@@ -224,6 +247,11 @@ async def _update(self, request: web.Request) -> web.Response:
224247 async def _update_many (self , request : web .Request ) -> web .Response :
225248 await check_permission (request , f"admin.{ self .name } .edit" , context = (request , None ))
226249 query = parse_obj_as (UpdateManyParams , request .query )
250+ # TODO(Pydantic): Dissallow extra arguments
251+ for k in query ["data" ]:
252+ if k not in self .inputs and k != "id" :
253+ raise web .HTTPBadRequest (reason = f"Invalid field '{ k } '" )
254+ query ["data" ] = parse_obj_as (self ._record_type , query ["data" ])
227255
228256 # Check original records are allowed by permission filters.
229257 originals = await self .get_many ({"ids" : query ["ids" ]})
@@ -243,6 +271,7 @@ async def _update_many(self, request: web.Request) -> web.Response:
243271 async def _delete (self , request : web .Request ) -> web .Response :
244272 await check_permission (request , f"admin.{ self .name } .delete" , context = (request , None ))
245273 query = parse_obj_as (DeleteParams , request .query )
274+ query ["previousData" ] = parse_obj_as (self ._record_type , query ["previousData" ])
246275
247276 original = await self .get_one ({"id" : query ["id" ]})
248277 if not await permits (request , f"admin.{ self .name } .delete" , context = (request , original )):
0 commit comments