Skip to content

Commit 0429e97

Browse files
Cache permissions during request (#657)
1 parent 12d83e3 commit 0429e97

File tree

3 files changed

+60
-32
lines changed

3 files changed

+60
-32
lines changed

aiohttp_admin/backends/abc.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, Literal, Optional, TypedDict, Union
88

99
from aiohttp import web
10-
from aiohttp_security import authorized_userid, check_permission, permits
10+
from aiohttp_security import check_permission, permits
1111
from pydantic import Json, parse_obj_as
1212

1313
from ..security import permissions_as_dict
@@ -87,7 +87,7 @@ async def filter_by_permissions(self, request: web.Request, perm_type: str,
8787
"""Return a filtered record containing permissible fields only."""
8888
return {k: v for k, v in record.items()
8989
if await permits(request, f"admin.{self.name}.{k}.{perm_type}",
90-
context=original or record)}
90+
context=(request, original or record))}
9191

9292
@abstractmethod
9393
async def get_list(self, params: GetListParams) -> tuple[list[Record], int]:
@@ -120,71 +120,69 @@ async def delete_many(self, params: DeleteManyParams) -> list[Union[int, str]]:
120120
# https://marmelab.com/react-admin/DataProviderWriting.html
121121

122122
async def _get_list(self, request: web.Request) -> web.Response:
123-
await check_permission(request, f"admin.{self.name}.view")
123+
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
124124
query = parse_obj_as(GetListParams, request.query)
125125

126126
# Add filters from advanced permissions.
127-
if request.app["identity_callback"]:
128-
identity = await authorized_userid(request)
129-
user_details = await request.app["identity_callback"](identity)
130-
permissions = permissions_as_dict(user_details["permissions"])
131-
filters = permissions.get(f"admin.{self.name}.view",
132-
permissions.get(f"admin.{self.name}.*", {}))
133-
for k, v in filters.items():
134-
query["filter"][k] = v
127+
# The permissions will be cached on the request from a previous permissions check.
128+
permissions = permissions_as_dict(request["aiohttpadmin_permissions"])
129+
filters = permissions.get(f"admin.{self.name}.view",
130+
permissions.get(f"admin.{self.name}.*", {}))
131+
for k, v in filters.items():
132+
query["filter"][k] = v
135133

136134
results, total = await self.get_list(query)
137135
results = [await self.filter_by_permissions(request, "view", r) for r in results]
138136
results = [r for r in results if await permits(request, f"admin.{self.name}.view",
139-
context=r)]
137+
context=(request, r))]
140138
return json_response({"data": results, "total": total})
141139

142140
async def _get_one(self, request: web.Request) -> web.Response:
143-
await check_permission(request, f"admin.{self.name}.view")
141+
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
144142
query = parse_obj_as(GetOneParams, request.query)
145143

146144
result = await self.get_one(query)
147-
if not await permits(request, f"admin.{self.name}.view", context=result):
145+
if not await permits(request, f"admin.{self.name}.view", context=(request, result)):
148146
raise web.HTTPForbidden()
149147
result = await self.filter_by_permissions(request, "view", result)
150148
return json_response({"data": result})
151149

152150
async def _get_many(self, request: web.Request) -> web.Response:
153-
await check_permission(request, f"admin.{self.name}.view")
151+
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
154152
query = parse_obj_as(GetManyParams, request.query)
155153

156154
results = await self.get_many(query)
157155
results = [await self.filter_by_permissions(request, "view", r) for r in results
158-
if await permits(request, f"admin.{self.name}.view", context=r)]
156+
if await permits(request, f"admin.{self.name}.view", context=(request, r))]
159157
return json_response({"data": results})
160158

161159
async def _create(self, request: web.Request) -> web.Response:
162160
query = parse_obj_as(CreateParams, request.query)
163-
await check_permission(request, f"admin.{self.name}.add", context=query["data"])
161+
await check_permission(request, f"admin.{self.name}.add", context=(request, query["data"]))
164162
for k, v in query["data"].items():
165163
if v is not None:
166164
await check_permission(request, f"admin.{self.name}.{k}.add",
167-
context=query["data"])
165+
context=(request, query["data"]))
168166

169167
result = await self.create(query)
170168
result = await self.filter_by_permissions(request, "view", result)
171169
return json_response({"data": result})
172170

173171
async def _update(self, request: web.Request) -> web.Response:
174-
await check_permission(request, f"admin.{self.name}.edit")
172+
await check_permission(request, f"admin.{self.name}.edit", context=(request, None))
175173
query = parse_obj_as(UpdateParams, request.query)
176174

177175
# Check original record is allowed by permission filters.
178176
original = await self.get_one({"id": query["id"]})
179-
if not await permits(request, f"admin.{self.name}.edit", context=original):
177+
if not await permits(request, f"admin.{self.name}.edit", context=(request, original)):
180178
raise web.HTTPForbidden()
181179

182180
# Filter rather than forbid because react-admin still sends fields without an
183181
# input component. The query may not be the complete dict though, so we must
184182
# pass original for testing.
185183
query["data"] = await self.filter_by_permissions(request, "edit", query["data"], original)
186184
# Check new values are allowed by permission filters.
187-
if not await permits(request, f"admin.{self.name}.edit", context=query["data"]):
185+
if not await permits(request, f"admin.{self.name}.edit", context=(request, query["data"])):
188186
raise web.HTTPForbidden()
189187

190188
if not query["data"]:
@@ -195,24 +193,24 @@ async def _update(self, request: web.Request) -> web.Response:
195193
return json_response({"data": result})
196194

197195
async def _delete(self, request: web.Request) -> web.Response:
198-
await check_permission(request, f"admin.{self.name}.delete")
196+
await check_permission(request, f"admin.{self.name}.delete", context=(request, None))
199197
query = parse_obj_as(DeleteParams, request.query)
200198

201199
original = await self.get_one({"id": query["id"]})
202-
if not await permits(request, f"admin.{self.name}.delete", context=original):
200+
if not await permits(request, f"admin.{self.name}.delete", context=(request, original)):
203201
raise web.HTTPForbidden()
204202

205203
result = await self.delete(query)
206204
result = await self.filter_by_permissions(request, "view", result)
207205
return json_response({"data": result})
208206

209207
async def _delete_many(self, request: web.Request) -> web.Response:
210-
await check_permission(request, f"admin.{self.name}.delete")
208+
await check_permission(request, f"admin.{self.name}.delete", context=(request, None))
211209
query = parse_obj_as(DeleteManyParams, request.query)
212210

213211
originals = await self.get_many(query)
214212
allowed = await asyncio.gather(*(permits(request, f"admin.{self.name}.delete",
215-
context=r) for r in originals))
213+
context=(request, r)) for r in originals))
216214
if not all(allowed):
217215
raise web.HTTPForbidden()
218216

aiohttp_admin/security.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,25 @@ async def authorized_userid(self, identity: str) -> str:
6666
return identity
6767

6868
async def permits(self, identity: Optional[str], permission: Union[str, Enum],
69-
context: Optional[Mapping[str, object]] = None) -> bool:
69+
context: tuple[web.Request, Optional[Mapping[str, object]]]) -> bool:
7070
if identity is None:
7171
return False
72-
if self._identity_callback is None:
73-
permissions: Collection[str] = tuple(Permissions)
74-
else:
75-
user = await self._identity_callback(identity)
76-
permissions = user["permissions"]
77-
return has_permission(permission, permissions_as_dict(permissions), context)
72+
73+
try:
74+
request, record = context
75+
except (TypeError, ValueError):
76+
raise TypeError("Context must be `(request, record)` or `(request, None)`")
77+
78+
permissions: Optional[Collection[str]] = request.get("aiohttpadmin_permissions")
79+
if permissions is None:
80+
if self._identity_callback is None:
81+
permissions = tuple(Permissions)
82+
else:
83+
user = await self._identity_callback(identity)
84+
permissions = user["permissions"]
85+
# Cache permissions per request to avoid potentially dozens of DB calls.
86+
request["aiohttpadmin_permissions"] = permissions
87+
return has_permission(permission, permissions_as_dict(permissions), record)
7888

7989

8090
class TokenIdentityPolicy(SessionIdentityPolicy): # type: ignore[misc,no-any-unimported]

tests/test_security.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from typing import Awaitable, Callable, Optional
3+
from unittest import mock
34

45
from aiohttp.test_utils import TestClient
56
from aiohttp_security import AbstractAuthorizationPolicy
@@ -333,6 +334,25 @@ async def identity_callback(identity: Optional[str]) -> UserDetails:
333334
assert await resp.json() == {"data": {"id": 1}}
334335

335336

337+
async def test_permissions_cached(create_admin_client: _CreateClient, # type: ignore[no-any-unimported] # noqa: B950
338+
login: _Login) -> None:
339+
identity_callback = mock.AsyncMock(spec_set=(), return_value={"permissions": {"admin.*"}})
340+
admin_client = await create_admin_client(identity_callback)
341+
342+
assert admin_client.app
343+
url = admin_client.app["admin"].router["dummy2_get_list"].url_for()
344+
h = await login(admin_client)
345+
identity_callback.assert_called_once()
346+
identity_callback.reset_mock()
347+
348+
p = {"pagination": json.dumps({"page": 1, "perPage": 10}),
349+
"sort": json.dumps({"field": "id", "order": "DESC"}), "filter": "{}"}
350+
async with admin_client.get(url, params=p, headers=h) as resp:
351+
assert resp.status == 200
352+
353+
identity_callback.assert_called_once()
354+
355+
336356
async def test_permission_filter_list(create_admin_client: _CreateClient, # type: ignore[no-any-unimported] # noqa: B950
337357
login: _Login) -> None:
338358
async def identity_callback(identity: Optional[str]) -> UserDetails:

0 commit comments

Comments
 (0)