Skip to content

Commit 9a00126

Browse files
committed
fixed failing test
1 parent d6e8e26 commit 9a00126

File tree

2 files changed

+45
-38
lines changed

2 files changed

+45
-38
lines changed

ellar/common/params/args/base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,30 @@
3939
QueryHeaderResolverGenerator,
4040
)
4141

42-
BULK_RESOLVERS = {
42+
__BULK_RESOLVERS__ = {
4343
str(params.FormFieldInfo): FormArgsResolverGenerator,
4444
str(params.PathFieldInfo): PathArgsResolverGenerator,
4545
str(params.QueryFieldInfo): QueryHeaderResolverGenerator,
4646
str(params.HeaderFieldInfo): QueryHeaderResolverGenerator,
4747
}
4848

4949

50+
def add_get_resolver_generator(
51+
param: params.ParamFieldInfo, resolver_gen: t.Type[BulkArgsResolverGenerator]
52+
) -> None:
53+
"""
54+
Add a custom Bulk resolver generator a custom route function parameter field type
55+
:param param:
56+
:param resolver_gen:
57+
:return:
58+
"""
59+
__BULK_RESOLVERS__[str(param)] = resolver_gen # pragma: no cover
60+
61+
5062
def get_resolver_generator(
5163
param: params.ParamFieldInfo,
5264
) -> t.Type[BulkArgsResolverGenerator]:
53-
return BULK_RESOLVERS.get(str(type(param)), BulkArgsResolverGenerator)
65+
return __BULK_RESOLVERS__.get(str(type(param)), BulkArgsResolverGenerator)
5466

5567

5668
def get_annotation_type_and_default(

ellar/common/params/args/resolver_generators.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,11 @@ def __init__(self, pydantic_type: ModelField) -> None:
5454
self.param_field = pydantic_type
5555

5656
def validate(self, field_name: str, field: ModelField) -> None:
57-
pass
58-
# if not (is_scalar_field(field=field) or is_scalar_sequence_field(field)):
59-
# raise ImproperConfiguration(
60-
# f"field: '{field_name}' with annotation:'{field.type_}' in '{self.param_field.type_}'"
61-
# f"can't be processed. Field type is not a primitive type"
62-
# )
57+
if not (is_scalar_field(field=field) or is_scalar_sequence_field(field)):
58+
raise ImproperConfiguration(
59+
f"field: '{field_name}' with annotation:'{field.type_}' in '{self.param_field.type_}'"
60+
f"can't be processed. Field type is not a primitive type"
61+
)
6362

6463
def get_parameter_field(
6564
self,
@@ -97,42 +96,38 @@ def generate_resolvers(self, body_field_class: t.Type[FieldInfo]) -> None:
9796
field_info=field,
9897
)
9998
self.validate(k, model_field)
100-
field_info = model_field.field_info
10199

102-
if not isinstance(model_field.field_info, params.ParamFieldInfo):
103-
convert_underscores = getattr(
104-
self.param_field.field_info,
100+
convert_underscores = getattr(
101+
self.param_field.field_info,
102+
"convert_underscores",
103+
getattr(
104+
self.param_field.field_info.json_schema_extra,
105105
"convert_underscores",
106-
getattr(
107-
self.param_field.field_info.json_schema_extra,
108-
"convert_underscores",
109-
None,
110-
),
111-
)
112-
113-
keys = dict(
114-
FieldConstraintsDefaultValues,
115-
**model_field.field_info.extract_attributes_keys()
116-
if hasattr(model_field.field_info, "extract_attributes_keys")
117-
else {},
118-
**{"convert_underscores": convert_underscores}
119-
if convert_underscores
120-
else {},
121-
)
122-
123-
attrs = {
124-
k: getattr(model_field.field_info, k, v) for k, v in keys.items()
125-
}
126-
127-
model_field, field_info = self.get_parameter_field(
128-
k, model_field, attrs, body_field_class
129-
)
130-
resolver = field_info.create_resolver(model_field=model_field) # type:ignore[attr-defined]
106+
None,
107+
),
108+
)
109+
110+
keys = dict(
111+
FieldConstraintsDefaultValues,
112+
**model_field.field_info.extract_attributes_keys()
113+
if hasattr(model_field.field_info, "extract_attributes_keys")
114+
else {},
115+
**{"convert_underscores": convert_underscores}
116+
if convert_underscores
117+
else {},
118+
)
119+
120+
attrs = {k: getattr(model_field.field_info, k, v) for k, v in keys.items()}
121+
122+
model_field, field_info = self.get_parameter_field(
123+
k, model_field, attrs, body_field_class
124+
)
125+
resolver = field_info.create_resolver(model_field=model_field)
131126
resolvers.append(resolver)
132127

133128
if isinstance(self.param_field.field_info.json_schema_extra, dict):
134129
self.param_field.field_info.json_schema_extra[MULTI_RESOLVER_KEY] = (
135-
resolvers
130+
resolvers # type:ignore[assignment]
136131
)
137132

138133

0 commit comments

Comments
 (0)