Skip to content

Commit 2132bfa

Browse files
authored
Merge pull request #170 from python-ellar/extra_route_args_fix
Bug Fixes: Extra Route Args multiple field grouping
2 parents ef25d71 + 2464168 commit 2132bfa

File tree

3 files changed

+86
-34
lines changed

3 files changed

+86
-34
lines changed

ellar/common/params/args/base.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -210,36 +210,52 @@ def compute_route_parameter_list(
210210
), "Path params must be of one of the supported types"
211211
self._add_to_model(field=param_field)
212212
else:
213-
default_field_info = t.cast(
214-
t.Type[params.ParamFieldInfo],
215-
param_default
216-
if isinstance(param_default, FieldInfo)
217-
else params.QueryFieldInfo,
218-
)
219-
param_field = get_parameter_field(
213+
param_field = self._process_parameter_file(
220214
param_default=param_default,
221215
param_annotation=param_annotation,
222-
default_field_info=default_field_info,
223216
param_name=param_name,
224217
body_field_class=body_field_class,
225218
)
226-
if not isinstance(
227-
param_field.field_info, (params.BodyFieldInfo, params.FileFieldInfo)
228-
) and not is_scalar_field(field=param_field):
229-
if not is_scalar_sequence_field(param_field):
230-
if not lenient_issubclass(param_field.type_, BaseModel):
231-
raise ImproperConfiguration(
232-
f"{param_field.type_} type can't be processed as a field"
233-
)
234-
235-
bulk_resolver_generator_class = self.get_resolver_generator(
236-
param_default
237-
)
238-
bulk_resolver_generator_class(param_field).generate_resolvers(
239-
body_field_class=body_field_class
240-
)
241219
self._add_to_model(field=param_field)
242220

221+
def _process_parameter_file(
222+
self,
223+
*,
224+
param_default: t.Any,
225+
param_name: str,
226+
param_annotation: t.Type,
227+
body_field_class: t.Type[FieldInfo] = params.BodyFieldInfo,
228+
) -> ModelField:
229+
default_field_info = t.cast(
230+
t.Type[params.ParamFieldInfo],
231+
param_default
232+
if isinstance(param_default, FieldInfo)
233+
else params.QueryFieldInfo,
234+
)
235+
param_field = get_parameter_field(
236+
param_default=param_default,
237+
param_annotation=param_annotation,
238+
default_field_info=default_field_info,
239+
param_name=param_name,
240+
body_field_class=body_field_class,
241+
)
242+
if not isinstance(
243+
param_field.field_info, (params.BodyFieldInfo, params.FileFieldInfo)
244+
) and not is_scalar_field(field=param_field):
245+
if not is_scalar_sequence_field(param_field):
246+
if not lenient_issubclass(param_field.type_, BaseModel):
247+
raise ImproperConfiguration(
248+
f"{param_field.type_} type can't be processed as a field"
249+
)
250+
251+
bulk_resolver_generator_class = self.get_resolver_generator(
252+
param_default
253+
)
254+
bulk_resolver_generator_class(param_field).generate_resolvers(
255+
body_field_class=body_field_class
256+
)
257+
return param_field
258+
243259
def _add_system_parameters_to_dependency(
244260
self,
245261
*,
@@ -351,17 +367,10 @@ def _add_extra_route_args(
351367
):
352368
continue
353369

354-
default_field_info = t.cast(
355-
t.Type[params.ParamFieldInfo],
356-
param_default
357-
if isinstance(param_default, FieldInfo)
358-
else params.QueryFieldInfo,
359-
)
360-
param_field = get_parameter_field(
370+
param_field = self._process_parameter_file(
361371
param_default=param_default,
362-
param_annotation=param.annotation,
363-
default_field_info=default_field_info,
364-
param_name=param.name,
372+
param_annotation=param_annotation,
373+
param_name=param_name,
365374
)
366375
self._add_to_model(field=param_field, key=key)
367376

ellar/common/params/decorators/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing as t
22

3+
from ellar.common.utils import get_name
34
from ellar.pydantic.types import Undefined
45
from typing_extensions import Annotated
56

@@ -27,6 +28,7 @@
2728
class _ParamShortcut:
2829
def __init__(self, base_func: t.Callable) -> None:
2930
self._base_func = base_func
31+
self.name = get_name(base_func)
3032

3133
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
3234
return self._base_func(*args, **kwargs)
@@ -109,7 +111,7 @@ def P(
109111
Path = Annotated[T, param_functions.Path()]
110112
Query = Annotated[T, param_functions.Query()]
111113
WsBody = Annotated[T, param_functions.WsBody()]
112-
Inject = Annotated[T, t.Any]
114+
Inject = InjectShortcut()
113115

114116
else:
115117
Body = _ParamShortcut(param_functions.Body)

tests/test_routing/test_extra_args.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ def _wrapper(*args, **kwargs):
6363
return _wrapper
6464

6565

66+
def add_additional_grouped_field(func):
67+
# EXTRA ARGS SETUP
68+
query1 = ExtraEndpointArg(name="query1", annotation=Query[Filter])
69+
70+
extra_args(query1)(func)
71+
72+
@wraps(func)
73+
def _wrapper(*args, **kwargs):
74+
resolved_query1: Filter = query1.resolve(kwargs)
75+
76+
response = func(*args, **kwargs)
77+
response.update(query1=resolved_query1.dict())
78+
return response
79+
80+
return _wrapper
81+
82+
6683
@get("/test")
6784
@add_extra_non_field_extra_args
6885
@add_additional_signature_to_endpoint
@@ -207,3 +224,27 @@ def test_query_params_extra():
207224

208225
response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50")
209226
assert response.status_code == 422
227+
228+
229+
def test_extra_args_as_grouped_fields():
230+
@get("/test-grouped")
231+
@add_additional_grouped_field
232+
def query_params_extra(
233+
request: Inject[Request],
234+
):
235+
return {}
236+
237+
tm.create_application().router.append(query_params_extra)
238+
client = tm.get_test_client()
239+
response = client.get(
240+
"/test-grouped?from=1&to=2&range=20&foo=1&range2=50&query1=somequery1&query2=somequery2"
241+
)
242+
assert response.json() == {
243+
"query1": {
244+
"from_datetime": "1970-01-01T00:00:01Z",
245+
"range": 20,
246+
"to_datetime": "1970-01-01T00:00:02Z",
247+
}
248+
}
249+
response = client.get("/test-grouped?query1=somequery1&query2=somequery2")
250+
assert response.status_code == 422

0 commit comments

Comments
 (0)