|
9 | 9 | # cython: optimize.unpack_method_calls=True |
10 | 10 | # cython: infer_types=True |
11 | 11 |
|
| 12 | +import dataclasses |
| 13 | +import inspect |
12 | 14 | import json |
13 | 15 | import logging |
14 | 16 | import sys |
@@ -239,45 +241,65 @@ cdef class InputFilter: |
239 | 241 | """ |
240 | 242 | self.conditions.append(condition) |
241 | 243 |
|
242 | | - cdef void _register_decorator_components(self): |
243 | | - """Register decorator-based components from the current class only.""" |
244 | | - cdef object cls, attr_value, conditions, validators, filters |
| 244 | + cpdef void _register_decorator_components(self): |
| 245 | + """Register decorator-based components from the current class and |
| 246 | + inheritance chain.""" |
| 247 | + cdef object cls, attr_value, conditions, validators, filters, base_cls |
245 | 248 | cdef str attr_name |
246 | 249 | cdef list dir_attrs |
247 | 250 | cdef FieldDescriptor field_desc |
| 251 | + cdef set added_conditions, added_global_validators, added_global_filters |
| 252 | + cdef object condition_id, validator_id, filter_id |
| 253 | + cdef object condition, validator, filter_instance |
248 | 254 |
|
249 | 255 | cls = self.__class__ |
250 | 256 | dir_attrs = dir(cls) |
251 | 257 |
|
252 | 258 | for attr_name in dir_attrs: |
253 | | - if (<bytes>attr_name.encode('utf-8')).startswith(b"_"): |
| 259 | + if attr_name.startswith("_"): |
254 | 260 | continue |
| 261 | + if hasattr(cls, attr_name): |
| 262 | + attr_value = getattr(cls, attr_name) |
| 263 | + if isinstance(attr_value, FieldDescriptor): |
| 264 | + self.fields[attr_name] = FieldModel( |
| 265 | + attr_value.required, |
| 266 | + attr_value.default, |
| 267 | + attr_value.fallback, |
| 268 | + attr_value.filters, |
| 269 | + attr_value.validators, |
| 270 | + attr_value.steps, |
| 271 | + attr_value.external_api, |
| 272 | + attr_value.copy, |
| 273 | + ) |
255 | 274 |
|
256 | | - attr_value = getattr(cls, attr_name, None) |
257 | | - if attr_value is not None and isinstance(attr_value, FieldDescriptor): |
258 | | - field_desc = <FieldDescriptor>attr_value |
259 | | - self.fields[attr_name] = FieldModel( |
260 | | - field_desc.required, |
261 | | - field_desc._default, |
262 | | - field_desc.fallback, |
263 | | - field_desc.filters, |
264 | | - field_desc.validators, |
265 | | - field_desc.steps, |
266 | | - field_desc.external_api, |
267 | | - field_desc.copy, |
268 | | - ) |
269 | | - |
270 | | - conditions = getattr(cls, "_conditions", None) |
271 | | - if conditions is not None: |
272 | | - self.conditions.extend(conditions) |
273 | | - |
274 | | - validators = getattr(cls, "_global_validators", None) |
275 | | - if validators is not None: |
276 | | - self.global_validators.extend(validators) |
277 | | - |
278 | | - filters = getattr(cls, "_global_filters", None) |
279 | | - if filters is not None: |
280 | | - self.global_filters.extend(filters) |
| 275 | + added_conditions = set() |
| 276 | + added_global_validators = set() |
| 277 | + added_global_filters = set() |
| 278 | + |
| 279 | + for base_cls in reversed(cls.__mro__): |
| 280 | + conditions = getattr(base_cls, "_conditions", None) |
| 281 | + if conditions is not None: |
| 282 | + for condition in conditions: |
| 283 | + condition_id = id(condition) |
| 284 | + if condition_id not in added_conditions: |
| 285 | + self.conditions.append(condition) |
| 286 | + added_conditions.add(condition_id) |
| 287 | + |
| 288 | + validators = getattr(base_cls, "_global_validators", None) |
| 289 | + if validators is not None: |
| 290 | + for validator in validators: |
| 291 | + validator_id = id(validator) |
| 292 | + if validator_id not in added_global_validators: |
| 293 | + self.global_validators.append(validator) |
| 294 | + added_global_validators.add(validator_id) |
| 295 | + |
| 296 | + filters = getattr(base_cls, "_global_filters", None) |
| 297 | + if filters is not None: |
| 298 | + for filter_instance in filters: |
| 299 | + filter_id = id(filter_instance) |
| 300 | + if filter_id not in added_global_filters: |
| 301 | + self.global_filters.append(filter_instance) |
| 302 | + added_global_filters.add(filter_id) |
281 | 303 |
|
282 | 304 | self.model_class = getattr(cls, "_model", self.model_class) |
283 | 305 |
|
@@ -731,7 +753,26 @@ cdef class InputFilter: |
731 | 753 | if self.model_class is None: |
732 | 754 | return self.validated_data |
733 | 755 |
|
734 | | - return self.model_class(**self.validated_data) |
| 756 | + try: |
| 757 | + return self.model_class(**self.validated_data) |
| 758 | + except TypeError: |
| 759 | + pass |
| 760 | + |
| 761 | + cdef set field_names = set() |
| 762 | + |
| 763 | + if dataclasses.is_dataclass(self.model_class): |
| 764 | + field_names = {f.name for f in dataclasses.fields(self.model_class)} |
| 765 | + elif hasattr(self.model_class, '__fields__'): |
| 766 | + field_names = set(self.model_class.__fields__.keys()) |
| 767 | + elif hasattr(self.model_class, '__annotations__'): |
| 768 | + field_names = set(self.model_class.__annotations__.keys()) |
| 769 | + else: |
| 770 | + sig = inspect.signature(self.model_class.__init__) |
| 771 | + field_names = set(sig.parameters.keys()) - {'self'} |
| 772 | + |
| 773 | + cdef dict filtered_data = {k: v for k, v in self.validated_data.items() if k in field_names} |
| 774 | + |
| 775 | + return self.model_class(**filtered_data) |
735 | 776 |
|
736 | 777 | cpdef void add_global_validator(self, BaseValidator validator): |
737 | 778 | """ |
|
0 commit comments