Skip to content

Commit b9c7335

Browse files
Support non-id primary keys (#679)
1 parent 9797727 commit b9c7335

File tree

4 files changed

+151
-21
lines changed

4 files changed

+151
-21
lines changed

aiohttp_admin/backends/abc.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import warnings
34
from abc import ABC, abstractmethod
45
from datetime import datetime
56
from enum import Enum
@@ -85,7 +86,11 @@ class AbstractAdminResource(ABC):
8586
name: str
8687
fields: dict[str, FieldState]
8788
inputs: dict[str, InputState]
88-
repr_field: str
89+
primary_key: str
90+
91+
def __init__(self) -> None:
92+
if "id" in self.fields and self.primary_key != "id":
93+
warnings.warn("A non-PK 'id' column is likely to break the admin.", stacklevel=2)
8994

9095
async def filter_by_permissions(self, request: web.Request, perm_type: str,
9196
record: Record, original: Optional[Record] = None) -> Record:
@@ -132,6 +137,10 @@ async def _get_list(self, request: web.Request) -> web.Response:
132137
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
133138
query = parse_obj_as(GetListParams, request.query)
134139

140+
# When sort order refers to "id", this should be translated to primary key.
141+
if query["sort"]["field"] == "id":
142+
query["sort"]["field"] = self.primary_key
143+
135144
# Add filters from advanced permissions.
136145
# The permissions will be cached on the request from a previous permissions check.
137146
permissions = permissions_as_dict(request["aiohttpadmin_permissions"])
@@ -144,6 +153,9 @@ async def _get_list(self, request: web.Request) -> web.Response:
144153
results = [await self.filter_by_permissions(request, "view", r) for r in results]
145154
results = [r for r in results if await permits(request, f"admin.{self.name}.view",
146155
context=(request, r))]
156+
# We need to set "id" for react-admin (in case there is no "id" primary key).
157+
for r in results:
158+
r["id"] = r[self.primary_key]
147159
return json_response({"data": results, "total": total})
148160

149161
async def _get_one(self, request: web.Request) -> web.Response:
@@ -154,6 +166,7 @@ async def _get_one(self, request: web.Request) -> web.Response:
154166
if not await permits(request, f"admin.{self.name}.view", context=(request, result)):
155167
raise web.HTTPForbidden()
156168
result = await self.filter_by_permissions(request, "view", result)
169+
result["id"] = result[self.primary_key]
157170
return json_response({"data": result})
158171

159172
async def _get_many(self, request: web.Request) -> web.Response:
@@ -163,6 +176,8 @@ async def _get_many(self, request: web.Request) -> web.Response:
163176
results = await self.get_many(query)
164177
results = [await self.filter_by_permissions(request, "view", r) for r in results
165178
if await permits(request, f"admin.{self.name}.view", context=(request, r))]
179+
for r in results:
180+
r["id"] = r[self.primary_key]
166181
return json_response({"data": results})
167182

168183
async def _create(self, request: web.Request) -> web.Response:
@@ -175,12 +190,16 @@ async def _create(self, request: web.Request) -> web.Response:
175190

176191
result = await self.create(query)
177192
result = await self.filter_by_permissions(request, "view", result)
193+
result["id"] = result[self.primary_key]
178194
return json_response({"data": result})
179195

180196
async def _update(self, request: web.Request) -> web.Response:
181197
await check_permission(request, f"admin.{self.name}.edit", context=(request, None))
182198
query = parse_obj_as(UpdateParams, request.query)
183199

200+
if self.primary_key != "id":
201+
query["data"].pop("id", None)
202+
184203
# Check original record is allowed by permission filters.
185204
original = await self.get_one({"id": query["id"]})
186205
if not await permits(request, f"admin.{self.name}.edit", context=(request, original)):
@@ -199,6 +218,7 @@ async def _update(self, request: web.Request) -> web.Response:
199218

200219
result = await self.update(query)
201220
result = await self.filter_by_permissions(request, "view", result)
221+
result["id"] = result[self.primary_key]
202222
return json_response({"data": result})
203223

204224
async def _update_many(self, request: web.Request) -> web.Response:
@@ -230,6 +250,7 @@ async def _delete(self, request: web.Request) -> web.Response:
230250

231251
result = await self.delete(query)
232252
result = await self.filter_by_permissions(request, "view", result)
253+
result["id"] = result[self.primary_key]
233254
return json_response({"data": result})
234255

235256
async def _delete_many(self, request: web.Request) -> web.Response:

aiohttp_admin/backends/sqlalchemy.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,15 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, Type[Declara
8282
self._db = db
8383
self._table = table
8484

85-
self._primary_key = tuple(filter(lambda c: table.c[c].primary_key, self._table.c.keys()))
86-
if not self._primary_key:
85+
pk = tuple(filter(lambda c: table.c[c].primary_key, self._table.c.keys()))
86+
if not pk:
8787
raise ValueError("No primary key found.")
88-
if len(self._primary_key) > 1:
88+
if len(pk) > 1:
8989
# TODO: Test composite primary key
9090
raise NotImplementedError("Composite keys not supported yet.")
91-
self.repr_field = self._primary_key[0]
91+
self.primary_key = pk[0]
92+
93+
super().__init__()
9294

9395
async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:
9496
per_page = params["pagination"]["perPage"]
@@ -112,7 +114,7 @@ async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:
112114

113115
async def get_one(self, params: GetOneParams) -> Record:
114116
async with self._db.connect() as conn:
115-
stmt = sa.select(self._table).where(self._table.c["id"] == params["id"])
117+
stmt = sa.select(self._table).where(self._table.c[self.primary_key] == params["id"])
116118
result = await conn.execute(stmt)
117119
try:
118120
return result.one()._asdict()
@@ -122,8 +124,7 @@ async def get_one(self, params: GetOneParams) -> Record:
122124

123125
async def get_many(self, params: GetManyParams) -> list[Record]:
124126
async with self._db.connect() as conn:
125-
# TODO: Handle primary key not called "id"
126-
stmt = sa.select(self._table).where(self._table.c["id"].in_(params["ids"]))
127+
stmt = sa.select(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
127128
result = await conn.execute(stmt)
128129
records = [r._asdict() for r in result]
129130
if records:
@@ -142,7 +143,7 @@ async def create(self, params: CreateParams) -> Record:
142143

143144
async def update(self, params: UpdateParams) -> Record:
144145
async with self._db.begin() as conn:
145-
stmt = sa.update(self._table).where(self._table.c["id"] == params["id"])
146+
stmt = sa.update(self._table).where(self._table.c[self.primary_key] == params["id"])
146147
stmt = stmt.values(params["data"]).returning(*self._table.c)
147148
try:
148149
row = await conn.execute(stmt)
@@ -157,8 +158,8 @@ async def update(self, params: UpdateParams) -> Record:
157158

158159
async def update_many(self, params: UpdateManyParams) -> list[Union[str, int]]:
159160
async with self._db.begin() as conn:
160-
stmt = sa.update(self._table).where(self._table.c["id"].in_(params["ids"]))
161-
stmt = stmt.values(params["data"]).returning(self._table.c["id"])
161+
stmt = sa.update(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
162+
stmt = stmt.values(params["data"]).returning(self._table.c[self.primary_key])
162163
try:
163164
r = await conn.scalars(stmt)
164165
except sa.exc.CompileError as e:
@@ -170,7 +171,7 @@ async def update_many(self, params: UpdateManyParams) -> list[Union[str, int]]:
170171

171172
async def delete(self, params: DeleteParams) -> Record:
172173
async with self._db.begin() as conn:
173-
stmt = sa.delete(self._table).where(self._table.c["id"] == params["id"])
174+
stmt = sa.delete(self._table).where(self._table.c[self.primary_key] == params["id"])
174175
row = await conn.execute(stmt.returning(*self._table.c))
175176
try:
176177
return row.one()._asdict()
@@ -180,9 +181,8 @@ async def delete(self, params: DeleteParams) -> Record:
180181

181182
async def delete_many(self, params: DeleteManyParams) -> list[Union[str, int]]:
182183
async with self._db.begin() as conn:
183-
# TODO: Handle primary key not called "id"
184-
stmt = sa.delete(self._table).where(self._table.c["id"].in_(params["ids"]))
185-
r = await conn.scalars(stmt.returning(self._table.c["id"]))
184+
stmt = sa.delete(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
185+
r = await conn.scalars(stmt.returning(self._table.c[self.primary_key]))
186186
ids = list(r)
187187
if ids:
188188
return ids

aiohttp_admin/routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def setup_resources(admin: web.Application, schema: Schema) -> None:
2525
if not all(f in m.fields for f in display_fields):
2626
raise ValueError(f"Display includes non-existent field {display_fields}")
2727

28-
repr_field = r.get("repr", m.repr_field)
28+
repr_field = r.get("repr", m.primary_key)
2929

3030
for k, v in m.inputs.items():
3131
if k in display_fields:

tests/test_backends_sqlalchemy.py

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
from typing import Type
1+
import json
2+
from typing import Awaitable, Callable, Type
23

4+
import pytest
35
import sqlalchemy as sa
4-
from sqlalchemy.ext.asyncio import AsyncEngine
6+
from aiohttp import web
7+
from aiohttp.test_utils import TestClient
8+
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
59
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
610

11+
import aiohttp_admin
12+
from _auth import check_credentials
713
from aiohttp_admin.backends.sqlalchemy import SAResource
814

15+
_Login = Callable[[TestClient], Awaitable[dict[str, str]]]
16+
917

1018
def test_pk(base: Type[DeclarativeBase], mock_engine: AsyncEngine) -> None:
1119
class TestModel(base): # type: ignore[misc,valid-type]
@@ -15,7 +23,7 @@ class TestModel(base): # type: ignore[misc,valid-type]
1523

1624
r = SAResource(mock_engine, TestModel)
1725
assert r.name == "dummy"
18-
assert r.repr_field == "id"
26+
assert r.primary_key == "id"
1927
assert r.fields == {
2028
"id": {"type": "NumberField", "props": {}},
2129
"num": {"type": "TextField", "props": {}}
@@ -34,7 +42,7 @@ def test_table(mock_engine: AsyncEngine) -> None:
3442

3543
r = SAResource(mock_engine, dummy_table)
3644
assert r.name == "dummy"
37-
assert r.repr_field == "id"
45+
assert r.primary_key == "id"
3846
assert r.fields == {
3947
"id": {"type": "NumberField", "props": {}},
4048
"num": {"type": "TextField", "props": {}}
@@ -57,7 +65,7 @@ class TestChildModel(base): # type: ignore[misc,valid-type]
5765

5866
r = SAResource(mock_engine, TestChildModel)
5967
assert r.name == "child"
60-
assert r.repr_field == "id"
68+
assert r.primary_key == "id"
6169
assert r.fields == {"id": {"type": "ReferenceField", "props": {"reference": "dummy"}}}
6270
# PK with FK constraint should be shown in create form.
6371
assert r.inputs == {"id": {
@@ -82,3 +90,104 @@ class TestOne(base): # type: ignore[misc,valid-type]
8290
"props": {"children": {"id": {"props": {}, "type": "NumberField"}},
8391
"label": "Ones", "reference": "one", "source": "id", "target": "many_id"}}
8492
assert "ones" not in r.inputs
93+
94+
95+
async def test_nonid_pk(base: Type[DeclarativeBase], mock_engine: AsyncEngine) -> None:
96+
class TestModel(base): # type: ignore[misc,valid-type]
97+
__tablename__ = "test"
98+
num: Mapped[int] = mapped_column(primary_key=True)
99+
other: Mapped[str]
100+
101+
r = SAResource(mock_engine, TestModel)
102+
assert r.name == "test"
103+
assert r.primary_key == "num"
104+
assert r.fields == {
105+
"num": {"type": "NumberField", "props": {}},
106+
"other": {"type": "TextField", "props": {}}
107+
}
108+
assert r.inputs == {
109+
"num": {"type": "NumberInput", "show_create": False, "props": {}},
110+
"other": {"type": "TextInput", "show_create": True, "props": {}}
111+
}
112+
113+
114+
async def test_id_nonpk(base: Type[DeclarativeBase], mock_engine: AsyncEngine) -> None:
115+
class NotPK(base): # type: ignore[misc,valid-type]
116+
__tablename__ = "notpk"
117+
name: Mapped[str] = mapped_column(primary_key=True)
118+
id: Mapped[int]
119+
120+
class CompositePK(base): # type: ignore[misc,valid-type]
121+
__tablename__ = "compound"
122+
id: Mapped[int] = mapped_column(primary_key=True)
123+
other: Mapped[int] = mapped_column(primary_key=True)
124+
125+
with pytest.warns(UserWarning, match="A non-PK 'id' column is likely to break the admin."):
126+
SAResource(mock_engine, NotPK)
127+
# TODO: Support composite PK.
128+
# with pytest.warns(UserWarning, match="'id' column in a composite PK is likely to"
129+
# + " break the admin"):
130+
# SAResource(mock_engine, CompositePK)
131+
132+
133+
async def test_nonid_pk_api(
134+
base: DeclarativeBase, aiohttp_client: Callable[[web.Application], Awaitable[TestClient]],
135+
login: _Login
136+
) -> None:
137+
class TestModel(base): # type: ignore[misc,valid-type]
138+
__tablename__ = "test"
139+
num: Mapped[int] = mapped_column(primary_key=True)
140+
other: Mapped[str]
141+
142+
app = web.Application()
143+
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
144+
db = async_sessionmaker(engine, expire_on_commit=False)
145+
async with engine.begin() as conn:
146+
await conn.run_sync(base.metadata.create_all)
147+
async with db.begin() as sess:
148+
sess.add(TestModel(num=5, other="foo"))
149+
sess.add(TestModel(num=8, other="bar"))
150+
151+
schema: aiohttp_admin.Schema = {
152+
"security": {
153+
"check_credentials": check_credentials,
154+
"secure": False
155+
},
156+
"resources": ({"model": SAResource(engine, TestModel)},)
157+
}
158+
app["admin"] = aiohttp_admin.setup(app, schema)
159+
160+
admin_client = await aiohttp_client(app)
161+
assert admin_client.app
162+
h = await login(admin_client)
163+
164+
url = app["admin"].router["test_get_list"].url_for()
165+
p = {"pagination": json.dumps({"page": 1, "perPage": 10}),
166+
"sort": json.dumps({"field": "id", "order": "DESC"}), "filter": "{}"}
167+
async with admin_client.get(url, params=p, headers=h) as resp:
168+
assert resp.status == 200
169+
assert await resp.json() == {"data": [{"id": 8, "num": 8, "other": "bar"},
170+
{"id": 5, "num": 5, "other": "foo"}], "total": 2}
171+
172+
url = app["admin"].router["test_get_one"].url_for()
173+
async with admin_client.get(url, params={"id": 8}, headers=h) as resp:
174+
assert resp.status == 200
175+
assert await resp.json() == {"data": {"id": 8, "num": 8, "other": "bar"}}
176+
177+
url = app["admin"].router["test_get_many"].url_for()
178+
async with admin_client.get(url, params={"ids": "[5, 8]"}, headers=h) as resp:
179+
assert resp.status == 200
180+
assert await resp.json() == {"data": [{"id": 5, "num": 5, "other": "foo"},
181+
{"id": 8, "num": 8, "other": "bar"}]}
182+
183+
url = app["admin"].router["test_create"].url_for()
184+
p = {"data": json.dumps({"num": 12, "other": "this"})}
185+
async with admin_client.post(url, params=p, headers=h) as resp:
186+
assert resp.status == 200
187+
assert await resp.json() == {"data": {"id": 12, "num": 12, "other": "this"}}
188+
189+
url = app["admin"].router["test_update"].url_for()
190+
p1 = {"id": 5, "data": json.dumps({"id": 5, "other": "that"}), "previousData": "{}"}
191+
async with admin_client.put(url, params=p1, headers=h) as resp:
192+
assert resp.status == 200
193+
assert await resp.json() == {"data": {"id": 5, "num": 5, "other": "that"}}

0 commit comments

Comments
 (0)