Skip to content

Commit 9dc8f41

Browse files
authored
269 Stop multi-dimensional array and email filtering from breaking (#270)
* stop multidimensional arrays from breaking `FastAPIWrapper` * replace `list[list]` with `list` * fix linter errors * wip * fix `ModelFilters` * update `pydantic_model_filters` * use `Array. _get_dimensions` instead of `is_multidimensional_array` * use latest piccolo * ignore mypy warning * add test for filtering email * add a test for multidimensional arrays
1 parent 2df940d commit 9dc8f41

File tree

10 files changed

+317
-109
lines changed

10 files changed

+317
-109
lines changed

piccolo_api/crud/endpoints.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from piccolo.query.methods.select import Select
2626
from piccolo.table import Table
2727
from piccolo.utils.encoding import dump_json
28+
from piccolo.utils.pydantic import create_pydantic_model
2829
from starlette.requests import Request
2930
from starlette.responses import JSONResponse, Response
3031
from starlette.routing import Route, Router
@@ -38,7 +39,6 @@
3839
)
3940

4041
from .exceptions import MalformedQuery, db_exception_handler
41-
from .serializers import create_pydantic_model
4242
from .validators import Validators, apply_validators
4343

4444
if t.TYPE_CHECKING: # pragma: no cover
@@ -257,9 +257,9 @@ def __init__(
257257
table=table, exclude_secrets=exclude_secrets, max_joins=max_joins
258258
)
259259
schema_extra["visible_fields_options"] = self.visible_fields_options
260-
schema_extra[
261-
"primary_key_name"
262-
] = self.table._meta.primary_key._meta.name
260+
schema_extra["primary_key_name"] = (
261+
self.table._meta.primary_key._meta.name
262+
)
263263
self.schema_extra = schema_extra
264264

265265
root_methods = ["GET"]
@@ -282,9 +282,9 @@ def __init__(
282282
Route(
283283
path="/{row_id:str}/",
284284
endpoint=self.detail,
285-
methods=["GET"]
286-
if read_only
287-
else ["GET", "PUT", "DELETE", "PATCH"],
285+
methods=(
286+
["GET"] if read_only else ["GET", "PUT", "DELETE", "PATCH"]
287+
),
288288
),
289289
]
290290

@@ -330,8 +330,8 @@ def pydantic_model_output(self) -> t.Type[pydantic.BaseModel]:
330330
@property
331331
def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]:
332332
"""
333-
All fields are optional, which is useful for serialising filters,
334-
where a user can filter on any number of fields.
333+
All fields are optional, which is useful for PATCH requests, which
334+
may only update some fields.
335335
"""
336336
return create_pydantic_model(
337337
self.table,
@@ -340,6 +340,63 @@ def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]:
340340
model_name=f"{self.table.__name__}Optional",
341341
)
342342

343+
@property
344+
def pydantic_model_filters(self) -> t.Type[pydantic.BaseModel]:
345+
"""
346+
Used for serialising query params, which are used for filtering.
347+
348+
A special case is multidimensional arrays - if we have this::
349+
350+
my_column = Array(Array(Varchar()))
351+
352+
Even though the type is ``list[list[str]]``, this isn't allowed as a
353+
query parameter. Instead, we use ``list[str]``.
354+
355+
Also, for ``Email`` columns, we don't want to validate that it's a
356+
correct email address when filtering, as someone may want to filter
357+
by 'gmail', for example.
358+
359+
"""
360+
model_name = f"{self.table.__name__}Filters"
361+
362+
multidimensional_array_columns = [
363+
i
364+
for i in self.table._meta.array_columns
365+
if i._get_dimensions() > 1
366+
]
367+
368+
email_columns = self.table._meta.email_columns
369+
370+
base_model = create_pydantic_model(
371+
self.table,
372+
include_default_columns=True,
373+
exclude_columns=(*multidimensional_array_columns, *email_columns),
374+
all_optional=True,
375+
model_name=model_name,
376+
)
377+
378+
if multidimensional_array_columns or email_columns:
379+
return pydantic.create_model(
380+
__model_name=model_name,
381+
__base__=base_model,
382+
**{
383+
i._meta.name: (
384+
t.Optional[t.List[i._get_inner_value_type()]], # type: ignore # noqa: E501
385+
pydantic.Field(default=None),
386+
)
387+
for i in multidimensional_array_columns
388+
},
389+
**{
390+
i._meta.name: (
391+
t.Optional[str],
392+
pydantic.Field(default=None),
393+
)
394+
for i in email_columns
395+
},
396+
)
397+
else:
398+
return base_model
399+
343400
def pydantic_model_plural(
344401
self,
345402
include_readable=False,
@@ -716,7 +773,7 @@ def _apply_filters(
716773
"""
717774
fields = params.fields
718775
if fields:
719-
model_dict = self.pydantic_model_optional(**fields).model_dump()
776+
model_dict = self.pydantic_model_filters(**fields).model_dump()
720777
for field_name in fields.keys():
721778
value = model_dict.get(field_name, ...)
722779
if value is ...:
@@ -860,9 +917,9 @@ async def get_all(
860917
curr_page_len = curr_page_len + offset
861918
count = await self.table.count().run()
862919
curr_page_string = f"{offset}-{curr_page_len}"
863-
headers[
864-
"Content-Range"
865-
] = f"{plural_name} {curr_page_string}/{count}"
920+
headers["Content-Range"] = (
921+
f"{plural_name} {curr_page_string}/{count}"
922+
)
866923

867924
# We need to serialise it ourselves, in case there are datetime
868925
# fields.
@@ -1155,9 +1212,13 @@ async def patch_single(
11551212
cls = self.table
11561213

11571214
try:
1158-
values = {getattr(cls, key): getattr(model, key) for key in data.keys()}
1215+
values = {
1216+
getattr(cls, key): getattr(model, key) for key in data.keys()
1217+
}
11591218
except AttributeError:
1160-
unrecognised_keys = set(data.keys()) - set(model.model_dump().keys())
1219+
unrecognised_keys = set(data.keys()) - set(
1220+
model.model_dump().keys()
1221+
)
11611222
return Response(
11621223
f"Unrecognised keys - {unrecognised_keys}.",
11631224
status_code=400,
@@ -1180,15 +1241,19 @@ async def patch_single(
11801241
return Response(f"{e}", status_code=400)
11811242
values["password"] = cls.hash_password(password)
11821243
try:
1183-
await cls.update(values).where(cls._meta.primary_key == row_id).run()
1244+
await cls.update(values).where(
1245+
cls._meta.primary_key == row_id
1246+
).run()
11841247
new_row = (
11851248
await cls.select(exclude_secrets=self.exclude_secrets)
11861249
.where(cls._meta.primary_key == row_id)
11871250
.first()
11881251
.run()
11891252
)
11901253
assert new_row
1191-
return CustomJSONResponse(self.pydantic_model(**new_row).model_dump_json())
1254+
return CustomJSONResponse(
1255+
self.pydantic_model(**new_row).model_dump_json()
1256+
)
11921257
except ValueError:
11931258
return Response("Unable to save the resource.", status_code=500)
11941259

piccolo_api/fastapi/endpoints.py

Lines changed: 15 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,11 @@
1717
from pydantic.main import BaseModel
1818

1919
from piccolo_api.crud.endpoints import PiccoloCRUD
20+
from piccolo_api.utils.types import get_type
2021

2122
ANNOTATIONS: t.DefaultDict = defaultdict(dict)
2223

2324

24-
try:
25-
# Python 3.10 and above
26-
from types import UnionType # type: ignore
27-
except ImportError:
28-
29-
class UnionType: # type: ignore
30-
...
31-
32-
3325
class HTTPMethod(str, Enum):
3426
get = "GET"
3527
delete = "DELETE"
@@ -85,41 +77,6 @@ class ReferencesModel(BaseModel):
8577
references: t.List[ReferenceModel]
8678

8779

88-
def _get_type(type_: t.Type) -> t.Type:
89-
"""
90-
Extract the inner type from an optional if necessary, otherwise return
91-
the type as is.
92-
93-
For example::
94-
95-
>>> _get_type(Optional[int])
96-
int
97-
98-
>>> _get_type(int | None)
99-
int
100-
101-
>>> _get_type(int)
102-
int
103-
104-
>>> _get_type(list[str])
105-
list[str]
106-
107-
"""
108-
origin = t.get_origin(type_)
109-
110-
# Note: even if `t.Optional` is passed in, the origin is still a
111-
# `t.Union` or `UnionType` depending on the Python version.
112-
if any(origin is i for i in (t.Union, UnionType)):
113-
union_args = t.get_args(type_)
114-
115-
NoneType = type(None)
116-
117-
if len(union_args) == 2 and NoneType in union_args:
118-
return [i for i in union_args if i is not NoneType][0]
119-
120-
return type_
121-
122-
12380
class FastAPIWrapper:
12481
"""
12582
Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI.
@@ -160,6 +117,7 @@ def __init__(
160117
self.ModelIn = piccolo_crud.pydantic_model
161118
self.ModelOptional = piccolo_crud.pydantic_model_optional
162119
self.ModelPlural = piccolo_crud.pydantic_model_plural()
120+
self.ModelFilters = piccolo_crud.pydantic_model_filters
163121

164122
self.alias = f"{piccolo_crud.table._meta.tablename}__{id(self)}"
165123

@@ -180,7 +138,7 @@ async def get(request: Request, **kwargs):
180138

181139
self.modify_signature(
182140
endpoint=get,
183-
model=self.ModelOut,
141+
model=self.ModelFilters,
184142
http_method=HTTPMethod.get,
185143
allow_ordering=True,
186144
allow_pagination=True,
@@ -243,7 +201,7 @@ async def count(request: Request, **kwargs):
243201
return await piccolo_crud.get_count(request=request)
244202

245203
self.modify_signature(
246-
endpoint=count, model=self.ModelOut, http_method=HTTPMethod.get
204+
endpoint=count, model=self.ModelFilters, http_method=HTTPMethod.get
247205
)
248206

249207
fastapi_app.add_api_route(
@@ -301,7 +259,7 @@ async def delete(request: Request, **kwargs):
301259

302260
self.modify_signature(
303261
endpoint=delete,
304-
model=self.ModelOut,
262+
model=self.ModelFilters,
305263
http_method=HTTPMethod.delete,
306264
)
307265

@@ -325,9 +283,9 @@ async def post(request: Request, model):
325283
"""
326284
return await piccolo_crud.root(request=request)
327285

328-
post.__annotations__[
329-
"model"
330-
] = f"ANNOTATIONS['{self.alias}']['ModelIn']"
286+
post.__annotations__["model"] = (
287+
f"ANNOTATIONS['{self.alias}']['ModelIn']"
288+
)
331289

332290
fastapi_app.add_api_route(
333291
path=root_url,
@@ -386,9 +344,9 @@ async def put(row_id: str, request: Request, model):
386344
"""
387345
return await piccolo_crud.detail(request=request)
388346

389-
put.__annotations__[
390-
"model"
391-
] = f"ANNOTATIONS['{self.alias}']['ModelIn']"
347+
put.__annotations__["model"] = (
348+
f"ANNOTATIONS['{self.alias}']['ModelIn']"
349+
)
392350

393351
fastapi_app.add_api_route(
394352
path=self.join_urls(root_url, "/{row_id:str}/"),
@@ -410,9 +368,9 @@ async def patch(row_id: str, request: Request, model):
410368
"""
411369
return await piccolo_crud.detail(request=request)
412370

413-
patch.__annotations__[
414-
"model"
415-
] = f"ANNOTATIONS['{self.alias}']['ModelOptional']"
371+
patch.__annotations__["model"] = (
372+
f"ANNOTATIONS['{self.alias}']['ModelOptional']"
373+
)
416374

417375
fastapi_app.add_api_route(
418376
path=self.join_urls(root_url, "/{row_id:str}/"),
@@ -460,7 +418,7 @@ def modify_signature(
460418
for field_name, _field in model.model_fields.items():
461419
annotation = _field.annotation
462420
assert annotation is not None
463-
type_ = _get_type(annotation)
421+
type_ = get_type(annotation)
464422

465423
parameters.append(
466424
Parameter(

piccolo_api/utils/__init__.py

Whitespace-only changes.

piccolo_api/utils/types.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
2+
Utils for extracting information from complex, nested types.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import typing as t
8+
9+
try:
10+
# Python 3.10 and above
11+
from types import UnionType # type: ignore
12+
except ImportError:
13+
14+
class UnionType: # type: ignore
15+
...
16+
17+
18+
def get_type(type_: t.Type) -> t.Type:
19+
"""
20+
Extract the inner type from an optional if necessary, otherwise return
21+
the type as is.
22+
23+
For example::
24+
25+
>>> get_type(Optional[int])
26+
int
27+
28+
>>> get_type(int | None)
29+
int
30+
31+
>>> get_type(int)
32+
int
33+
34+
>>> _get_type(list[str])
35+
list[str]
36+
37+
"""
38+
origin = t.get_origin(type_)
39+
40+
# Note: even if `t.Optional` is passed in, the origin is still a
41+
# `t.Union` or `UnionType` depending on the Python version.
42+
if any(origin is i for i in (t.Union, UnionType)):
43+
union_args = t.get_args(type_)
44+
45+
NoneType = type(None)
46+
47+
if len(union_args) == 2 and NoneType in union_args:
48+
return [i for i in union_args if i is not NoneType][0]
49+
50+
return type_
51+
52+
53+
__all__ = ("get_type",)

requirements/dev-requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
black==24.3.0
2-
isort==5.12.0
3-
twine==4.0.2
4-
mypy==1.5.1
2+
isort==5.13.2
3+
twine==5.0.0
4+
mypy==1.9.0
55
pip-upgrader==1.4.15
6-
wheel==0.41.2
7-
setuptools==68.2.2
6+
wheel==0.43.0
7+
setuptools==69.2.0

requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Jinja2>=2.11.0
2-
piccolo[postgres]>=1.0a3
2+
piccolo[postgres]>=1.5
33
pydantic[email]>=2.0
44
python-multipart>=0.0.5
55
fastapi>=0.100.0

0 commit comments

Comments
 (0)