Skip to content

Commit d6e8e26

Browse files
committed
refactored route function resolvers
1 parent 2b28f96 commit d6e8e26

File tree

24 files changed

+390
-215
lines changed

24 files changed

+390
-215
lines changed

ellar/common/params/args/base.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ellar.common.interfaces import IExecutionContext
1313
from ellar.pydantic import (
1414
BaseModel,
15-
ErrorWrapper,
1615
FieldInfo,
1716
ModelField,
1817
evaluate_forwardref,
@@ -30,6 +29,7 @@
3029
IRouteParameterResolver,
3130
SystemParameterResolver,
3231
)
32+
from ..resolvers.base import ResolverResult
3333
from .extra_args import ExtraEndpointArg
3434
from .factory import get_parameter_field
3535
from .resolver_generators import (
@@ -318,27 +318,21 @@ def _add_to_model(self, *, field: ModelField, key: t.Optional[str] = None) -> No
318318
field_info.create_resolver(model_field=field)
319319
)
320320

321-
async def resolve_dependencies(
322-
self, *, ctx: IExecutionContext
323-
) -> t.Tuple[t.Dict[str, t.Any], t.List[ErrorWrapper]]:
324-
values: t.Dict[str, t.Any] = {}
325-
errors: t.List[ErrorWrapper] = []
321+
async def resolve_dependencies(self, *, ctx: IExecutionContext) -> ResolverResult:
322+
body_resolver = await self.resolve_body(ctx)
326323

327-
await self.resolve_body(ctx, values, errors)
328-
329-
if not errors:
324+
if body_resolver and not body_resolver.errors:
330325
for parameter_resolver in self._route_models:
331-
value_, value_errors = await parameter_resolver.resolve(ctx=ctx)
332-
if value_:
333-
values.update(value_)
334-
if value_errors:
326+
res = await parameter_resolver.resolve(ctx=ctx)
327+
if res.data:
328+
body_resolver.data.update(res.data)
329+
if res.errors:
335330
_errors = (
336-
value_errors
337-
if isinstance(value_errors, list)
338-
else [value_errors]
331+
res.errors if isinstance(res.errors, list) else [res.errors]
339332
)
340-
errors += _errors
341-
return values, errors
333+
body_resolver.errors.extend(_errors)
334+
body_resolver.raw_data.update(res.raw_data)
335+
return body_resolver
342336

343337
def compute_extra_route_args(self) -> None:
344338
self._add_extra_route_args(*self._extra_endpoint_args)
@@ -372,10 +366,9 @@ def _add_extra_route_args(
372366
)
373367
self._add_to_model(field=param_field, key=key)
374368

375-
async def resolve_body(
376-
self, ctx: IExecutionContext, values: t.Dict, errors: t.List
377-
) -> None:
369+
async def resolve_body(self, ctx: IExecutionContext) -> ResolverResult:
378370
"""Body Resolver Implementation"""
371+
return ResolverResult({}, [], {})
379372

380373
def __deepcopy__(
381374
self, memodict: t.Optional[t.Dict] = None

ellar/common/params/args/request_model.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
IRouteParameterResolver,
1919
RouteParameterModelField,
2020
)
21+
from ..resolvers.base import ResolverResult
2122
from .base import EndpointArgsModel
2223
from .extra_args import ExtraEndpointArg
2324

@@ -166,23 +167,16 @@ def build_body_field(self) -> None:
166167
final_field = create_model_field(
167168
name="body",
168169
type_=pydantic_body_model,
169-
field_info=body_field_info(**field_info_kwargs),
170+
field_info=body_field_info(**field_info_kwargs, ellar_body=True),
170171
)
171172
final_field.field_info = t.cast(
172173
params.ParamFieldInfo, final_field.field_info
173174
)
174175
check_file_field(final_field)
175176
self.body_resolver = final_field.field_info.create_resolver(final_field)
176177

177-
async def resolve_body(
178-
self, ctx: IExecutionContext, values: t.Dict, errors: t.List
179-
) -> None:
178+
async def resolve_body(self, ctx: IExecutionContext) -> ResolverResult:
180179
if not self.body_resolver:
181-
return
182-
183-
body, errors_ = await self.body_resolver.resolve(ctx=ctx)
184-
if errors_:
185-
assert isinstance(errors_, list)
186-
errors.extend(errors_)
187-
return
188-
values.update(body)
180+
return ResolverResult({}, [], {})
181+
182+
return await self.body_resolver.resolve(ctx=ctx)

ellar/common/params/args/resolver_generators.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
sequence_types,
77
)
88
from ellar.common.exceptions import ImproperConfiguration
9+
from ellar.common.params import params
910
from ellar.pydantic import (
1011
BaseModel,
1112
FieldConstraintsDefaultValues,
@@ -15,8 +16,8 @@
1516
is_scalar_field,
1617
is_scalar_sequence_field,
1718
)
19+
from ellar.pydantic.utils import is_field_annotation_nullable
1820

19-
from .. import params
2021
from .factory import get_parameter_field
2122

2223

@@ -53,11 +54,12 @@ def __init__(self, pydantic_type: ModelField) -> None:
5354
self.param_field = pydantic_type
5455

5556
def validate(self, field_name: str, field: ModelField) -> None:
56-
if not is_scalar_field(field=field):
57-
raise ImproperConfiguration(
58-
f"field: '{field_name}' with annotation:'{field.type_}' in '{self.param_field.type_}'"
59-
f"can't be processed. Field type is not a primitive type"
60-
)
57+
pass
58+
# if not (is_scalar_field(field=field) or is_scalar_sequence_field(field)):
59+
# raise ImproperConfiguration(
60+
# f"field: '{field_name}' with annotation:'{field.type_}' in '{self.param_field.type_}'"
61+
# f"can't be processed. Field type is not a primitive type"
62+
# )
6163

6264
def get_parameter_field(
6365
self,
@@ -82,6 +84,11 @@ def get_parameter_field(
8284
def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
8385
resolvers = []
8486
for k, field in self.pydantic_outer_type.model_fields.items():
87+
field.default = (
88+
None
89+
if is_field_annotation_nullable(field.annotation)
90+
else field.default
91+
)
8592
model_field = create_model_field(
8693
name=k,
8794
type_=field.annotation, # type:ignore[arg-type]
@@ -90,35 +97,42 @@ def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
9097
field_info=field,
9198
)
9299
self.validate(k, model_field)
100+
field_info = model_field.field_info
93101

94-
convert_underscores = getattr(
95-
self.param_field.field_info,
96-
"convert_underscores",
97-
getattr(
98-
self.param_field.field_info.json_schema_extra,
102+
if not isinstance(model_field.field_info, params.ParamFieldInfo):
103+
convert_underscores = getattr(
104+
self.param_field.field_info,
99105
"convert_underscores",
100-
None,
101-
),
102-
)
103-
104-
keys = dict(
105-
FieldConstraintsDefaultValues,
106-
**{"convert_underscores": convert_underscores}
107-
if convert_underscores
108-
else {},
109-
)
110-
111-
attrs = {k: getattr(model_field.field_info, k, v) for k, v in keys.items()}
112-
113-
model_field, field_info = self.get_parameter_field(
114-
k, model_field, attrs, body_field_class
115-
)
116-
resolver = field_info.create_resolver(model_field=model_field)
106+
getattr(
107+
self.param_field.field_info.json_schema_extra,
108+
"convert_underscores",
109+
None,
110+
),
111+
)
112+
113+
keys = dict(
114+
FieldConstraintsDefaultValues,
115+
**model_field.field_info.extract_attributes_keys()
116+
if hasattr(model_field.field_info, "extract_attributes_keys")
117+
else {},
118+
**{"convert_underscores": convert_underscores}
119+
if convert_underscores
120+
else {},
121+
)
122+
123+
attrs = {
124+
k: getattr(model_field.field_info, k, v) for k, v in keys.items()
125+
}
126+
127+
model_field, field_info = self.get_parameter_field(
128+
k, model_field, attrs, body_field_class
129+
)
130+
resolver = field_info.create_resolver(model_field=model_field) # type:ignore[attr-defined]
117131
resolvers.append(resolver)
118132

119133
if isinstance(self.param_field.field_info.json_schema_extra, dict):
120134
self.param_field.field_info.json_schema_extra[MULTI_RESOLVER_KEY] = (
121-
resolvers # type:ignore[assignment]
135+
resolvers
122136
)
123137

124138

ellar/common/params/args/websocket_model.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import typing as t
22

33
from ellar.common.interfaces import IExecutionContext
4-
from ellar.pydantic import ErrorWrapper, FieldInfo
4+
from ellar.pydantic import FieldInfo
55
from starlette.convertors import Convertor
66

77
from .. import params
88
from ..resolvers import BaseRouteParameterResolver, WsBodyParameterResolver
9+
from ..resolvers.base import ResolverResult
910
from .base import EndpointArgsModel
1011
from .extra_args import ExtraEndpointArg
1112

@@ -50,15 +51,17 @@ def compute_route_parameter_list(
5051

5152
async def resolve_ws_body_dependencies(
5253
self, *, ctx: IExecutionContext, body_data: t.Any
53-
) -> t.Tuple[t.Dict[str, t.Any], t.List[ErrorWrapper]]:
54+
) -> ResolverResult:
5455
values: t.Dict[str, t.Any] = {}
55-
errors: t.List[ErrorWrapper] = []
56+
errors = []
57+
raw_data = {}
5658
for parameter_resolver in self.body_resolver or []:
5759
parameter_resolver = t.cast(WsBodyParameterResolver, parameter_resolver)
58-
value_, errors_ = await parameter_resolver.resolve(ctx=ctx, body=body_data)
59-
if value_:
60-
values.update(value_)
61-
if errors_:
62-
assert isinstance(errors_, list)
63-
errors.extend(errors_)
64-
return values, errors
60+
res = await parameter_resolver.resolve(ctx=ctx, body=body_data)
61+
if res.data:
62+
values.update(res.data)
63+
if res.errors:
64+
assert isinstance(res.errors, list)
65+
errors.extend(res.errors)
66+
raw_data.update({parameter_resolver.model_field.name: res.raw_data})
67+
return ResolverResult(values, errors, raw_data)

ellar/common/params/params.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
FileParameterResolver,
1616
FormParameterResolver,
1717
HeaderParameterResolver,
18+
IRouteParameterResolver,
1819
PathParameterResolver,
1920
QueryParameterResolver,
2021
WsBodyParameterResolver,
@@ -67,6 +68,9 @@ def __init__(
6768
json_schema_extra: t.Union[t.Dict[str, t.Any], None] = None,
6869
**extra: t.Any,
6970
):
71+
# identifies a field wrapping many resolvers/field-infos
72+
self._ellar_body = extra.pop("ellar_body", False)
73+
7074
self.deprecated = deprecated
7175
self.include_in_schema = include_in_schema
7276

@@ -108,7 +112,9 @@ def __init__(
108112

109113
super().__init__(**init_kwargs)
110114

111-
def create_resolver(self, model_field: ModelField) -> BaseRouteParameterResolver:
115+
def create_resolver(
116+
self, model_field: ModelField
117+
) -> t.Union[BaseRouteParameterResolver, IRouteParameterResolver]:
112118
multiple_resolvers = model_field.field_info.json_schema_extra.pop( # type:ignore[union-attr]
113119
MULTI_RESOLVER_KEY, None
114120
)
@@ -417,7 +423,9 @@ def __init__(
417423
self.embed = True
418424
self.media_type = media_type or self.MEDIA_TYPE
419425

420-
def create_resolver(self, model_field: ModelField) -> BaseRouteParameterResolver:
426+
def create_resolver(
427+
self, model_field: ModelField
428+
) -> t.Union[BaseRouteParameterResolver, IRouteParameterResolver]:
421429
multiple_resolvers = model_field.field_info.json_schema_extra.pop( # type:ignore[union-attr]
422430
MULTI_RESOLVER_KEY, []
423431
)

ellar/common/params/resolvers/base.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from ..params import ParamFieldInfo
1212

1313

14+
class ResolverResult(t.NamedTuple):
15+
data: t.Optional[t.Any]
16+
errors: t.Optional[t.List[t.Dict[str, t.Any]]]
17+
raw_data: t.Dict[str, t.Any]
18+
19+
1420
class RouteParameterModelField(ModelField):
1521
field_info: "ParamFieldInfo"
1622

@@ -20,16 +26,25 @@ class IRouteParameterResolver(ABC, metaclass=ABCMeta):
2026

2127
@abstractmethod
2228
@t.no_type_check
23-
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple:
29+
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
2430
"""Resolve handle"""
2531

32+
@abstractmethod
33+
@t.no_type_check
34+
def create_raw_data(self, data: t.Any) -> t.Dict:
35+
"""Essential for debugging"""
36+
2637

2738
class BaseRouteParameterResolver(IRouteParameterResolver, ABC):
2839
def __init__(self, model_field: ModelField, *args: t.Any, **kwargs: t.Any) -> None:
2940
self.model_field: RouteParameterModelField = t.cast(
3041
RouteParameterModelField, model_field
3142
)
3243

44+
def create_raw_data(self, data: t.Any) -> t.Dict:
45+
"""Essential for debugging"""
46+
return {self.model_field.name: data}
47+
3348
def assert_field_info(self) -> None:
3449
from .. import params
3550

@@ -47,11 +62,11 @@ def validate_error_sequence(cls, errors: t.Any) -> t.List[t.Any]:
4762
return []
4863
return regenerate_error_with_loc(errors=errors, loc_prefix=())
4964

50-
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple:
65+
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
5166
value_ = await self.resolve_handle(*args, **kwargs)
5267
return value_
5368

5469
@abstractmethod
5570
@t.no_type_check
56-
async def resolve_handle(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple:
71+
async def resolve_handle(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
5772
"""resolver action"""

0 commit comments

Comments
 (0)