Skip to content

Commit b22a618

Browse files
Request Context Inside Hook Callbacks (#167)
* Found it could be useful to have access to the request context inside of hooks. Here is an implementation that gives that option. Wrote tests for it. Hope this is suitable. * Passing request context to hook callbacks. Here is an implementation that gives that option. Wrote tests for it. Hope this is suitable. Ran lint script * Fixed the helper function name in test_hooks * fix mypy error * add docs for dependency injection Co-authored-by: Daniel Townsend <[email protected]>
1 parent c79baf6 commit b22a618

File tree

4 files changed

+168
-10
lines changed

4 files changed

+168
-10
lines changed

docs/source/crud/hooks.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pre_save
6464
~~~~~~~~
6565

6666
This hook runs during POST requests, prior to inserting data into the database.
67-
It takes a single parameter, ``row``, and should return the same:
67+
It takes a single parameter, ``row``, and should return the row:
6868

6969
.. code-block:: python
7070
@@ -131,6 +131,17 @@ It takes one parameter, ``row_id`` which is the id of the row to be deleted.
131131
]
132132
)
133133
134+
Dependency injection
135+
~~~~~~~~~~~~~~~~~~~~
136+
137+
Each hook can optionally receive the ``Starlette`` request object. Just
138+
add ``request`` as an argument in your hook, and it'll be injected automatically.
139+
140+
.. code-block:: python
141+
142+
async def set_movie_rating_10(row: Movie, request: Request):
143+
...
144+
134145
-------------------------------------------------------------------------------
135146

136147
Source

piccolo_api/crud/endpoints.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,10 @@ async def post_single(
810810
row = self.table(**model.dict())
811811
if self._hook_map:
812812
row = await execute_post_hooks(
813-
hooks=self._hook_map, hook_type=HookType.pre_save, row=row
813+
hooks=self._hook_map,
814+
hook_type=HookType.pre_save,
815+
row=row,
816+
request=request,
814817
)
815818
response = await row.save().run()
816819
json = dump_json(response)
@@ -1054,6 +1057,7 @@ async def patch_single(
10541057
hook_type=HookType.pre_patch,
10551058
row_id=row_id,
10561059
values=values,
1060+
request=request,
10571061
)
10581062

10591063
try:
@@ -1083,6 +1087,7 @@ async def delete_single(
10831087
hooks=self._hook_map,
10841088
hook_type=HookType.pre_delete,
10851089
row_id=row_id,
1090+
request=request,
10861091
)
10871092

10881093
try:

piccolo_api/crud/hooks.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from enum import Enum
44

55
from piccolo.table import Table
6+
from starlette.requests import Request
67

78

89
class HookType(Enum):
@@ -22,13 +23,23 @@ def __init__(self, hook_type: HookType, callable: t.Callable) -> None:
2223

2324

2425
async def execute_post_hooks(
25-
hooks: t.Dict[HookType, t.List[Hook]], hook_type: HookType, row: Table
26+
hooks: t.Dict[HookType, t.List[Hook]],
27+
hook_type: HookType,
28+
row: Table,
29+
request: Request,
2630
):
2731
for hook in hooks.get(hook_type, []):
32+
signature = inspect.signature(hook.callable)
33+
kwargs: t.Dict[str, t.Any] = dict(row=row)
34+
# Include request in hook call arguments if possible
35+
if {i for i in signature.parameters.keys()}.intersection(
36+
{"kwargs", "request"}
37+
):
38+
kwargs.update(request=request)
2839
if inspect.iscoroutinefunction(hook.callable):
29-
row = await hook.callable(row)
40+
row = await hook.callable(**kwargs)
3041
else:
31-
row = hook.callable(row)
42+
row = hook.callable(**kwargs)
3243
return row
3344

3445

@@ -37,20 +48,38 @@ async def execute_patch_hooks(
3748
hook_type: HookType,
3849
row_id: t.Any,
3950
values: t.Dict[t.Any, t.Any],
51+
request: Request,
4052
) -> t.Dict[t.Any, t.Any]:
4153
for hook in hooks.get(hook_type, []):
54+
signature = inspect.signature(hook.callable)
55+
kwargs = dict(row_id=row_id, values=values)
56+
# Include request in hook call arguments if possible
57+
if {i for i in signature.parameters.keys()}.intersection(
58+
{"kwargs", "request"}
59+
):
60+
kwargs.update(request=request)
4261
if inspect.iscoroutinefunction(hook.callable):
43-
values = await hook.callable(row_id=row_id, values=values)
62+
values = await hook.callable(**kwargs)
4463
else:
45-
values = hook.callable(row_id=row_id, values=values)
64+
values = hook.callable(**kwargs)
4665
return values
4766

4867

4968
async def execute_delete_hooks(
50-
hooks: t.Dict[HookType, t.List[Hook]], hook_type: HookType, row_id: t.Any
69+
hooks: t.Dict[HookType, t.List[Hook]],
70+
hook_type: HookType,
71+
row_id: t.Any,
72+
request: Request,
5173
):
5274
for hook in hooks.get(hook_type, []):
75+
signature = inspect.signature(hook.callable)
76+
kwargs = dict(row_id=row_id)
77+
# Include request in hook call arguments if possible
78+
if {i for i in signature.parameters.keys()}.intersection(
79+
{"kwargs", "request"}
80+
):
81+
kwargs.update(request=request)
5382
if inspect.iscoroutinefunction(hook.callable):
54-
await hook.callable(row_id=row_id)
83+
await hook.callable(**kwargs)
5584
else:
56-
hook.callable(row_id=row_id)
85+
hook.callable(**kwargs)

tests/crud/test_hooks.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest import TestCase
22

3+
from fastapi import Request
34
from piccolo.columns import Integer, Varchar
45
from piccolo.columns.readable import Readable
56
from piccolo.table import Table
@@ -39,17 +40,62 @@ async def look_up_existing(row_id: int, values: dict):
3940
return values
4041

4142

43+
async def add_additional_name_details(
44+
row_id: int, values: dict, request: Request
45+
):
46+
director = request.query_params.get("director_name", "")
47+
values["name"] = values["name"] + f" ({director})"
48+
return values
49+
50+
51+
async def additional_name_details(row: Movie, request: Request):
52+
director = request.query_params.get("director_name", "")
53+
row["name"] = f"{row.name} ({director})"
54+
return row
55+
56+
57+
async def raises_exception(row_id: int, request: Request):
58+
if request.query_params.get("director_name", False):
59+
raise Exception("Test Passed")
60+
61+
4262
async def failing_hook(row_id: int):
4363
raise Exception("hook failed")
4464

4565

66+
# TODO - add test for a non-async hook.
4667
class TestPostHooks(TestCase):
4768
def setUp(self):
4869
Movie.create_table(if_not_exists=True).run_sync()
4970

5071
def tearDown(self):
5172
Movie.alter().drop_table().run_sync()
5273

74+
def test_request_context_passed_to_post_hook(self):
75+
"""
76+
Make sure request context can be passed to post hook
77+
callable
78+
"""
79+
client = TestClient(
80+
PiccoloCRUD(
81+
table=Movie,
82+
read_only=False,
83+
hooks=[
84+
Hook(
85+
hook_type=HookType.pre_save,
86+
callable=additional_name_details,
87+
)
88+
],
89+
)
90+
)
91+
json_req = {
92+
"name": "Star Wars",
93+
"rating": 93,
94+
}
95+
_ = client.post("/", json=json_req, params={"director_name": "George"})
96+
movie = Movie.objects().first().run_sync()
97+
self.assertEqual(movie.name, "Star Wars (George)")
98+
5399
def test_single_pre_post_hook(self):
54100
"""
55101
Make sure single hook executes
@@ -96,6 +142,47 @@ def test_multi_pre_post_hooks(self):
96142
movie = Movie.objects().first().run_sync()
97143
self.assertEqual(movie.rating, 20)
98144

145+
def test_request_context_passed_to_patch_hook(self):
146+
"""
147+
Make sure request context can be passed to patch hook
148+
callable
149+
"""
150+
client = TestClient(
151+
PiccoloCRUD(
152+
table=Movie,
153+
read_only=False,
154+
hooks=[
155+
Hook(
156+
hook_type=HookType.pre_patch,
157+
callable=add_additional_name_details,
158+
)
159+
],
160+
)
161+
)
162+
163+
movie = Movie(name="Star Wars", rating=93)
164+
movie.save().run_sync()
165+
166+
new_name = "Star Wars: A New Hope"
167+
new_name_modified = new_name + " (George)"
168+
169+
json_req = {
170+
"name": new_name,
171+
}
172+
173+
response = client.patch(
174+
f"/{movie.id}/", json=json_req, params={"director_name": "George"}
175+
)
176+
self.assertTrue(response.status_code == 200)
177+
178+
# Make sure the row is returned:
179+
response_json = response.json()
180+
self.assertTrue(response_json["name"] == new_name_modified)
181+
182+
# Make sure the underlying database row was changed:
183+
movies = Movie.select().run_sync()
184+
self.assertTrue(movies[0]["name"] == new_name_modified)
185+
99186
def test_pre_patch_hook(self):
100187
"""
101188
Make sure pre_patch hook executes successfully
@@ -159,6 +246,32 @@ def test_pre_patch_hook_db_lookup(self):
159246
movies = Movie.select().run_sync()
160247
self.assertTrue(movies[0]["name"] == original_name)
161248

249+
def test_request_context_passed_to_delete_hook(self):
250+
"""
251+
Make sure request context can be passed to patch hook
252+
callable
253+
"""
254+
client = TestClient(
255+
PiccoloCRUD(
256+
table=Movie,
257+
read_only=False,
258+
hooks=[
259+
Hook(
260+
hook_type=HookType.pre_delete,
261+
callable=raises_exception,
262+
)
263+
],
264+
)
265+
)
266+
267+
movie = Movie(name="Star Wars", rating=10)
268+
movie.save().run_sync()
269+
270+
with self.assertRaises(Exception, msg="Test Passed"):
271+
_ = client.delete(
272+
f"/{movie.id}/", params={"director_name": "George"}
273+
)
274+
162275
def test_delete_hook_fails(self):
163276
"""
164277
Make sure failing pre_delete hook bubbles up

0 commit comments

Comments
 (0)