Skip to content

Commit a6eef5b

Browse files
Fix record type checking (#770)
1 parent a86812b commit a6eef5b

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed

aiohttp_admin/backends/abc.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import sys
44
import warnings
55
from abc import ABC, abstractmethod
6-
from collections.abc import Sequence
76
from datetime import date, datetime, time
87
from enum import Enum
98
from functools import cached_property, partial
@@ -112,13 +111,14 @@ class AbstractAdminResource(ABC):
112111
primary_key: str
113112
omit_fields: set[str]
114113

115-
def __init__(self) -> None:
114+
def __init__(self, record_type: Optional[dict[str, TypeAlias]] = None) -> None:
116115
if "id" in self.fields and self.primary_key != "id":
117116
warnings.warn("A non-PK 'id' column is likely to break the admin.", stacklevel=2)
118117

119-
d = {k: self._get_input_type(v) for k, v in self.inputs.items()}
120118
# For runtime type checking only.
121-
self._record_type = TypedDict("RecordType", d, total=False) # type: ignore[misc]
119+
if record_type is None:
120+
record_type = {k: Any for k in self.inputs}
121+
self._record_type = TypedDict("RecordType", record_type, total=False) # type: ignore[misc]
122122

123123
async def filter_by_permissions(self, request: web.Request, perm_type: str,
124124
record: Record, original: Optional[Record] = None) -> Record:
@@ -319,12 +319,6 @@ async def _delete_many(self, request: web.Request) -> web.Response:
319319
raise web.HTTPNotFound()
320320
return json_response({"data": ids})
321321

322-
def _get_input_type(self, inp: InputState) -> TypeAlias:
323-
t = INPUT_TYPES.get(inp["type"], str)
324-
validators = inp.get("props", {}).get("validate", ())
325-
assert isinstance(validators, Sequence) # noqa: S101
326-
return t if any(v["name"] == "required" for v in validators) else Optional[t]
327-
328322
@cached_property
329323
def routes(self) -> tuple[web.RouteDef, ...]:
330324
"""Routes to act on this resource.

aiohttp_admin/backends/sqlalchemy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, type[Declara
168168
self.fields = {}
169169
self.inputs = {}
170170
self.omit_fields = set()
171+
record_type = {}
171172
for c in table.c.values():
172173
if c.foreign_keys:
173174
field = "ReferenceField"
@@ -203,6 +204,10 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, type[Declara
203204
props["validate"] = self._get_validators(table, c)
204205
self.inputs[c.name] = comp(inp, props) # type: ignore[assignment]
205206
self.inputs[c.name]["show_create"] = show
207+
field_type: Any = c.type.python_type
208+
if c.nullable:
209+
field_type = Optional[field_type]
210+
record_type[c.name] = field_type
206211

207212
if not isinstance(model_or_table, sa.Table):
208213
# Append fields to represent ORM relationships.
@@ -257,7 +262,7 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, type[Declara
257262
raise NotImplementedError("Composite keys not supported yet.")
258263
self.primary_key = pk[0]
259264

260-
super().__init__()
265+
super().__init__(record_type)
261266

262267
@handle_errors
263268
async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:

tests/test_backends_sqlalchemy.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from collections.abc import Awaitable, Callable
33
from datetime import date, datetime
4-
from typing import Union
4+
from typing import Optional, Union
55

66
import pytest
77
import sqlalchemy as sa
@@ -439,3 +439,56 @@ class Wrong(base): # type: ignore[misc,valid-type]
439439

440440
with pytest.raises(ValueError, match="not an attribute"):
441441
permission_for(M, filters={Wrong.id: 1})
442+
443+
444+
async def test_record_type(
445+
base: DeclarativeBase, aiohttp_client: Callable[[web.Application], Awaitable[TestClient]],
446+
login: _Login
447+
) -> None:
448+
class TestModel(base): # type: ignore[misc,valid-type]
449+
__tablename__ = "test"
450+
id: Mapped[int] = mapped_column(primary_key=True)
451+
foo: Mapped[Optional[bool]]
452+
bar: Mapped[int]
453+
454+
app = web.Application()
455+
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
456+
async with engine.begin() as conn:
457+
await conn.run_sync(base.metadata.create_all)
458+
459+
schema: aiohttp_admin.Schema = {
460+
"security": {
461+
"check_credentials": check_credentials,
462+
"secure": False
463+
},
464+
"resources": ({"model": SAResource(engine, TestModel)},)
465+
}
466+
app["admin"] = aiohttp_admin.setup(app, schema)
467+
468+
admin_client = await aiohttp_client(app)
469+
assert admin_client.app
470+
h = await login(admin_client)
471+
472+
url = app["admin"].router["test_create"].url_for()
473+
p = {"data": json.dumps({"foo": True, "bar": 5})}
474+
async with admin_client.post(url, params=p, headers=h) as resp:
475+
assert resp.status == 200
476+
assert await resp.json() == {"data": {"id": 1, "foo": True, "bar": 5}}
477+
p = {"data": json.dumps({"foo": None, "bar": -1})}
478+
async with admin_client.post(url, params=p, headers=h) as resp:
479+
assert resp.status == 200
480+
assert await resp.json() == {"data": {"id": 2, "foo": None, "bar": -1}}
481+
482+
p = {"data": json.dumps({"foo": 5, "bar": "foo"})}
483+
async with admin_client.post(url, params=p, headers=h) as resp:
484+
assert resp.status == 400
485+
errors = await resp.json()
486+
assert any(e["loc"] == ["foo"] and e["type"] == "bool_parsing" for e in errors)
487+
assert any(e["loc"] == ["bar"] and e["type"] == "int_parsing" for e in errors)
488+
489+
p = {"data": json.dumps({"foo": "foo", "bar": None})}
490+
async with admin_client.post(url, params=p, headers=h) as resp:
491+
assert resp.status == 400
492+
errors = await resp.json()
493+
assert any(e["loc"] == ["foo"] and e["type"] == "bool_parsing" for e in errors)
494+
assert any(e["loc"] == ["bar"] and e["type"] == "int_type" for e in errors)

0 commit comments

Comments
 (0)