Skip to content

Commit b734d2b

Browse files
authored
Merge pull request #236 from python-ellar/resolver_refactor
fix: Unified route function resolvers
2 parents a60c571 + caa7bb5 commit b734d2b

File tree

26 files changed

+372
-238
lines changed

26 files changed

+372
-238
lines changed

ellar/common/interfaces/context.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from ellar.common.constants import empty_receive
55
from ellar.common.types import T, TReceive, TScope, TSend
6-
from ellar.di import injectable, request_scope
76
from starlette.requests import empty_send
87
from starlette.responses import Response
98

@@ -40,7 +39,6 @@ def get_client(self) -> "WebSocket":
4039
"""Returns WebSocket instance"""
4140

4241

43-
@injectable(scope=request_scope)
4442
class IHostContext(ABC, metaclass=ABCMeta):
4543
@abstractmethod
4644
def get_service_provider(self) -> "EllarInjector":
@@ -75,7 +73,6 @@ def user(self, value: t.Any) -> None:
7573
"""Sets user identity"""
7674

7775

78-
@injectable(scope=request_scope)
7976
class IExecutionContext(IHostContext, ABC):
8077
@abstractmethod
8178
def get_handler(self) -> t.Callable:
@@ -136,8 +133,8 @@ def create_context_type(self, context: IHostContext) -> T:
136133

137134

138135
class IHTTPConnectionContextFactory(SubHostContextFactory[IHTTPHostContext], ABC):
139-
context_typ: t.Type[IHTTPHostContext]
136+
"""HttpContext Factory Interface"""
140137

141138

142139
class IWebSocketContextFactory(SubHostContextFactory[IWebSocketHostContext], ABC):
143-
context_type: t.Type[IWebSocketHostContext]
140+
"""WebsocketContext Factory Interface"""

ellar/common/params/args/base.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ellar.common.interfaces import IExecutionContext
1313
from ellar.pydantic import (
1414
BaseModel,
15-
ErrorWrapper,
1615
FieldInfo,
1716
ModelField,
1817
evaluate_forwardref,
@@ -30,27 +29,43 @@
3029
IRouteParameterResolver,
3130
SystemParameterResolver,
3231
)
32+
from ..resolvers.base import ResolverResult
3333
from .extra_args import ExtraEndpointArg
3434
from .factory import get_parameter_field
3535
from .resolver_generators import (
3636
BulkArgsResolverGenerator,
37+
CookieResolverGenerator,
3738
FormArgsResolverGenerator,
3839
PathArgsResolverGenerator,
3940
QueryHeaderResolverGenerator,
4041
)
4142

42-
BULK_RESOLVERS = {
43+
__BULK_RESOLVERS__ = {
4344
str(params.FormFieldInfo): FormArgsResolverGenerator,
4445
str(params.PathFieldInfo): PathArgsResolverGenerator,
4546
str(params.QueryFieldInfo): QueryHeaderResolverGenerator,
4647
str(params.HeaderFieldInfo): QueryHeaderResolverGenerator,
48+
str(params.CookieFieldInfo): CookieResolverGenerator,
4749
}
4850

4951

52+
def add_get_resolver_generator(
53+
param: t.Type[params.ParamFieldInfo],
54+
resolver_gen: t.Type[BulkArgsResolverGenerator],
55+
) -> None:
56+
"""
57+
Add a custom Bulk resolver generator a custom route function parameter field type
58+
:param param:
59+
:param resolver_gen:
60+
:return:
61+
"""
62+
__BULK_RESOLVERS__[str(param)] = resolver_gen # pragma: no cover
63+
64+
5065
def get_resolver_generator(
5166
param: params.ParamFieldInfo,
5267
) -> t.Type[BulkArgsResolverGenerator]:
53-
return BULK_RESOLVERS.get(str(type(param)), BulkArgsResolverGenerator)
68+
return __BULK_RESOLVERS__.get(str(type(param)), BulkArgsResolverGenerator)
5469

5570

5671
def get_annotation_type_and_default(
@@ -318,27 +333,21 @@ def _add_to_model(self, *, field: ModelField, key: t.Optional[str] = None) -> No
318333
field_info.create_resolver(model_field=field)
319334
)
320335

321-
async def resolve_dependencies(
322-
self, *, ctx: IExecutionContext
323-
) -> t.Tuple[t.Dict[str, t.Any], t.List[ErrorWrapper]]:
324-
values: t.Dict[str, t.Any] = {}
325-
errors: t.List[ErrorWrapper] = []
326-
327-
await self.resolve_body(ctx, values, errors)
336+
async def resolve_dependencies(self, *, ctx: IExecutionContext) -> ResolverResult:
337+
body_resolver = await self.resolve_body(ctx)
328338

329-
if not errors:
339+
if body_resolver and not body_resolver.errors:
330340
for parameter_resolver in self._route_models:
331-
value_, value_errors = await parameter_resolver.resolve(ctx=ctx)
332-
if value_:
333-
values.update(value_)
334-
if value_errors:
341+
res = await parameter_resolver.resolve(ctx=ctx)
342+
if res.data:
343+
body_resolver.data.update(res.data)
344+
if res.errors:
335345
_errors = (
336-
value_errors
337-
if isinstance(value_errors, list)
338-
else [value_errors]
346+
res.errors if isinstance(res.errors, list) else [res.errors]
339347
)
340-
errors += _errors
341-
return values, errors
348+
body_resolver.errors.extend(_errors)
349+
body_resolver.raw_data.update(res.raw_data)
350+
return body_resolver
342351

343352
def compute_extra_route_args(self) -> None:
344353
self._add_extra_route_args(*self._extra_endpoint_args)
@@ -372,24 +381,9 @@ def _add_extra_route_args(
372381
)
373382
self._add_to_model(field=param_field, key=key)
374383

375-
async def resolve_body(
376-
self, ctx: IExecutionContext, values: t.Dict, errors: t.List
377-
) -> None:
384+
async def resolve_body(self, ctx: IExecutionContext) -> ResolverResult:
378385
"""Body Resolver Implementation"""
379-
380-
def __deepcopy__(
381-
self, memodict: t.Optional[t.Dict] = None
382-
) -> "EndpointArgsModel": # pragma: no cover
383-
if memodict is None:
384-
memodict = {}
385-
return self.__copy__(memodict)
386-
387-
def __copy__(
388-
self, memodict: t.Optional[t.Dict] = None
389-
) -> "EndpointArgsModel": # pragma: no cover
390-
if memodict is None:
391-
memodict = {}
392-
return self
386+
return ResolverResult({}, [], {})
393387

394388
def build_body_field(self) -> None: # pragma: no cover
395389
raise NotImplementedError

ellar/common/params/args/request_model.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
IRouteParameterResolver,
1919
RouteParameterModelField,
2020
)
21+
from ..resolvers.base import ResolverResult
2122
from .base import EndpointArgsModel
2223
from .extra_args import ExtraEndpointArg
2324

@@ -166,23 +167,16 @@ def build_body_field(self) -> None:
166167
final_field = create_model_field(
167168
name="body",
168169
type_=pydantic_body_model,
169-
field_info=body_field_info(**field_info_kwargs),
170+
field_info=body_field_info(**field_info_kwargs, ellar_body=True),
170171
)
171172
final_field.field_info = t.cast(
172173
params.ParamFieldInfo, final_field.field_info
173174
)
174175
check_file_field(final_field)
175176
self.body_resolver = final_field.field_info.create_resolver(final_field)
176177

177-
async def resolve_body(
178-
self, ctx: IExecutionContext, values: t.Dict, errors: t.List
179-
) -> None:
178+
async def resolve_body(self, ctx: IExecutionContext) -> ResolverResult:
180179
if not self.body_resolver:
181-
return
182-
183-
body, errors_ = await self.body_resolver.resolve(ctx=ctx)
184-
if errors_:
185-
assert isinstance(errors_, list)
186-
errors.extend(errors_)
187-
return
188-
values.update(body)
180+
return ResolverResult({}, [], {})
181+
182+
return await self.body_resolver.resolve(ctx=ctx)

ellar/common/params/args/resolver_generators.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
sequence_types,
77
)
88
from ellar.common.exceptions import ImproperConfiguration
9+
from ellar.common.params import params
910
from ellar.pydantic import (
1011
BaseModel,
1112
FieldConstraintsDefaultValues,
@@ -16,7 +17,6 @@
1617
is_scalar_sequence_field,
1718
)
1819

19-
from .. import params
2020
from .factory import get_parameter_field
2121

2222

@@ -53,7 +53,7 @@ def __init__(self, pydantic_type: ModelField) -> None:
5353
self.param_field = pydantic_type
5454

5555
def validate(self, field_name: str, field: ModelField) -> None:
56-
if not is_scalar_field(field=field):
56+
if not (is_scalar_field(field=field) or is_scalar_sequence_field(field)):
5757
raise ImproperConfiguration(
5858
f"field: '{field_name}' with annotation:'{field.type_}' in '{self.param_field.type_}'"
5959
f"can't be processed. Field type is not a primitive type"
@@ -103,6 +103,9 @@ def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
103103

104104
keys = dict(
105105
FieldConstraintsDefaultValues,
106+
**model_field.field_info.extract_attributes_keys()
107+
if hasattr(model_field.field_info, "extract_attributes_keys")
108+
else {},
106109
**{"convert_underscores": convert_underscores}
107110
if convert_underscores
108111
else {},
@@ -132,6 +135,15 @@ def validate(self, field_name: str, field: ModelField) -> None:
132135
)
133136

134137

138+
class CookieResolverGenerator(BulkArgsResolverGenerator):
139+
def validate(self, field_name: str, field: ModelField) -> None:
140+
if not is_scalar_field(field=field):
141+
raise ImproperConfiguration(
142+
f"field: '{field_name}' with annotation:'{field.type_}' in '{self.param_field.type_}'"
143+
f"can't be processed. Field type is not a primitive type"
144+
)
145+
146+
135147
class FormArgsResolverGenerator(QueryHeaderResolverGenerator):
136148
def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
137149
super().generate_resolvers(body_field_class=body_field_class)
@@ -140,13 +152,7 @@ def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
140152
] = True
141153

142154

143-
class PathArgsResolverGenerator(BulkArgsResolverGenerator):
144-
def validate(self, field_name: str, field: ModelField) -> None:
145-
if not is_scalar_field(field=field):
146-
raise ImproperConfiguration(
147-
"Path params must be of one of the supported types. Only primitive types"
148-
)
149-
155+
class PathArgsResolverGenerator(CookieResolverGenerator):
150156
def get_parameter_field(
151157
self,
152158
field_name: str,

ellar/common/params/args/websocket_model.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import typing as t
22

33
from ellar.common.interfaces import IExecutionContext
4-
from ellar.pydantic import ErrorWrapper, FieldInfo
4+
from ellar.pydantic import FieldInfo
55
from starlette.convertors import Convertor
66

77
from .. import params
88
from ..resolvers import BaseRouteParameterResolver, WsBodyParameterResolver
9+
from ..resolvers.base import ResolverResult
910
from .base import EndpointArgsModel
1011
from .extra_args import ExtraEndpointArg
1112

@@ -50,15 +51,17 @@ def compute_route_parameter_list(
5051

5152
async def resolve_ws_body_dependencies(
5253
self, *, ctx: IExecutionContext, body_data: t.Any
53-
) -> t.Tuple[t.Dict[str, t.Any], t.List[ErrorWrapper]]:
54+
) -> ResolverResult:
5455
values: t.Dict[str, t.Any] = {}
55-
errors: t.List[ErrorWrapper] = []
56+
errors = []
57+
raw_data = {}
5658
for parameter_resolver in self.body_resolver or []:
5759
parameter_resolver = t.cast(WsBodyParameterResolver, parameter_resolver)
58-
value_, errors_ = await parameter_resolver.resolve(ctx=ctx, body=body_data)
59-
if value_:
60-
values.update(value_)
61-
if errors_:
62-
assert isinstance(errors_, list)
63-
errors.extend(errors_)
64-
return values, errors
60+
res = await parameter_resolver.resolve(ctx=ctx, body=body_data)
61+
if res.data:
62+
values.update(res.data)
63+
if res.errors:
64+
assert isinstance(res.errors, list)
65+
errors.extend(res.errors)
66+
raw_data.update({parameter_resolver.model_field.name: res.raw_data})
67+
return ResolverResult(values, errors, raw_data)

ellar/common/params/params.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
FileParameterResolver,
1616
FormParameterResolver,
1717
HeaderParameterResolver,
18+
IRouteParameterResolver,
1819
PathParameterResolver,
1920
QueryParameterResolver,
2021
WsBodyParameterResolver,
@@ -67,6 +68,9 @@ def __init__(
6768
json_schema_extra: t.Union[t.Dict[str, t.Any], None] = None,
6869
**extra: t.Any,
6970
):
71+
# identifies a field wrapping many resolvers/field-infos
72+
self._ellar_body = extra.pop("ellar_body", False)
73+
7074
self.deprecated = deprecated
7175
self.include_in_schema = include_in_schema
7276

@@ -108,7 +112,9 @@ def __init__(
108112

109113
super().__init__(**init_kwargs)
110114

111-
def create_resolver(self, model_field: ModelField) -> BaseRouteParameterResolver:
115+
def create_resolver(
116+
self, model_field: ModelField
117+
) -> t.Union[BaseRouteParameterResolver, IRouteParameterResolver]:
112118
multiple_resolvers = model_field.field_info.json_schema_extra.pop( # type:ignore[union-attr]
113119
MULTI_RESOLVER_KEY, None
114120
)
@@ -417,7 +423,9 @@ def __init__(
417423
self.embed = True
418424
self.media_type = media_type or self.MEDIA_TYPE
419425

420-
def create_resolver(self, model_field: ModelField) -> BaseRouteParameterResolver:
426+
def create_resolver(
427+
self, model_field: ModelField
428+
) -> t.Union[BaseRouteParameterResolver, IRouteParameterResolver]:
421429
multiple_resolvers = model_field.field_info.json_schema_extra.pop( # type:ignore[union-attr]
422430
MULTI_RESOLVER_KEY, []
423431
)

ellar/common/params/resolvers/base.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from ..params import ParamFieldInfo
1212

1313

14+
class ResolverResult(t.NamedTuple):
15+
data: t.Optional[t.Any]
16+
errors: t.Optional[t.List[t.Dict[str, t.Any]]]
17+
raw_data: t.Dict[str, t.Any]
18+
19+
1420
class RouteParameterModelField(ModelField):
1521
field_info: "ParamFieldInfo"
1622

@@ -20,16 +26,25 @@ class IRouteParameterResolver(ABC, metaclass=ABCMeta):
2026

2127
@abstractmethod
2228
@t.no_type_check
23-
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple:
29+
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
2430
"""Resolve handle"""
2531

32+
@abstractmethod
33+
@t.no_type_check
34+
def create_raw_data(self, data: t.Any) -> t.Dict:
35+
"""Essential for debugging"""
36+
2637

2738
class BaseRouteParameterResolver(IRouteParameterResolver, ABC):
2839
def __init__(self, model_field: ModelField, *args: t.Any, **kwargs: t.Any) -> None:
2940
self.model_field: RouteParameterModelField = t.cast(
3041
RouteParameterModelField, model_field
3142
)
3243

44+
def create_raw_data(self, data: t.Any) -> t.Dict:
45+
"""Essential for debugging"""
46+
return {self.model_field.name: data}
47+
3348
def assert_field_info(self) -> None:
3449
from .. import params
3550

@@ -47,11 +62,11 @@ def validate_error_sequence(cls, errors: t.Any) -> t.List[t.Any]:
4762
return []
4863
return regenerate_error_with_loc(errors=errors, loc_prefix=())
4964

50-
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple:
65+
async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
5166
value_ = await self.resolve_handle(*args, **kwargs)
5267
return value_
5368

5469
@abstractmethod
5570
@t.no_type_check
56-
async def resolve_handle(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple:
71+
async def resolve_handle(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
5772
"""resolver action"""

0 commit comments

Comments
 (0)