diff --git a/AGENTS.md b/AGENTS.md index a2181ff..9109810 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -22,7 +22,8 @@ poetry add flask-inputfilter Create validation schemas by inheriting from `InputFilter`: ```python -from flask_inputfilter import InputFilter, field +from flask_inputfilter import InputFilter +from flask_inputfilter.declarative import field, global_filter from flask_inputfilter.filters import ToIntegerFilter, StringTrimFilter from flask_inputfilter.validators import IsIntegerValidator, LengthValidator @@ -41,8 +42,7 @@ class UserInputFilter(InputFilter): ) # Global filters/validators apply to all fields - _global_filters = [StringTrimFilter()] - _global_validators = [] + global_filter(StringTrimFilter()) ``` ### 2. Usage in Flask Routes diff --git a/README.md b/README.md index 2b5895e..a942a06 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ A more detailed guide can be found [in the docs](https://leandercs.github.io/fla ```python from flask_inputfilter import InputFilter -from flask_inputfilter.declarative import field +from flask_inputfilter.declarative import condition, field from flask_inputfilter.conditions import ExactlyOneOfCondition from flask_inputfilter.enums import RegexEnum from flask_inputfilter.filters import StringTrimFilter, ToIntegerFilter, ToNullFilter @@ -101,9 +101,9 @@ class UpdateZipcodeInputFilter(InputFilter): ] ) - _conditions = [ + condition( ExactlyOneOfCondition(['zipcode', 'city']) - ] + ) ``` diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 1b3a7b5..d862a5c 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -4,6 +4,53 @@ Changelog All notable changes to this project will be documented in this file. +[0.7.2] - 2025-09-28 +-------------------- + +Changed +^^^^^^^ +- Changed the way how to use the new decorator ``_condition``, ``_global_filter``, ``_global_validator`` and ``_model``. + They should no longer be assigned to a variable, but should be set with the corresponding declarative method. + + ``_condition = [Example()]`` => ``condition(Example())`` + + ``_global_filter = [Example()]`` => ``global_filter(Example())`` + + ``_global_validator = [Example()]`` => ``global_validator(Example())`` + + ``_model = Example`` => ``model(Example)`` + + The previous way is still supported, but is not recommended because it is not streightforward and could lead to confusion. + The new methods also support multiple calls and also mass assignment. + + Both of the following examples are valid and have the same effect: + + .. code-block:: python + + class ExampleInputFilter(InputFilter): + field1: str = field() + field2: str = field() + condition(ExactlyOneOfCondition(['field1', 'field2'])) + + field3: str = field() + field4: str = field() + condition(AtLeastOneOfCondition(['field3', 'field4'])) + + + .. code-block:: python + + class ExampleInputFilter(InputFilter): + field1: str = field() + field2: str = field() + field3: str = field() + field4: str = field() + + condition( + ExactlyOneOfCondition(['field1', 'field2']), + AtLeastOneOfCondition(['field3', 'field4']) + ) + + [0.7.1] - 2025-09-27 -------------------- @@ -24,13 +71,13 @@ Added ``self.add`` => ``field`` - ``self.add_condition`` => ``_conditions`` + ``self.add_condition`` => ``condition`` - ``self.add_global_filter`` => ``_global_filters`` + ``self.add_global_filter`` => ``global_filter`` - ``self.add_global_validator`` => ``_global_validators`` + ``self.add_global_validator`` => ``global_validator`` - ``self.add_model`` => ``_model`` + ``self.add_model`` => ``model`` **Before**: .. code-block:: python @@ -65,13 +112,13 @@ Added validators=[IsIntegerValidator()] ) - _conditions = [ExactlyOneOfCondition(['zipcode', 'city'])] + condition(ExactlyOneOfCondition(['zipcode', 'city'])) - _global_filters = [StringTrimFilter()] + global_filter(StringTrimFilter()) - _global_validators = [IsStringValidator()] + global_validator(IsStringValidator()) - _model = UserModel + model(UserModel) The Change is fully backward compatible, but the new way is more readable and maintainable. diff --git a/docs/source/index.rst b/docs/source/index.rst index 266fa3f..f7eb859 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -77,7 +77,7 @@ Definition validators=[IsStringValidator()] ) - _conditions = [ExactlyOneOfCondition(['zipcode', 'city'])] + condition(ExactlyOneOfCondition(['zipcode', 'city'])) Usage ^^^^^ diff --git a/docs/source/options/condition.rst b/docs/source/options/condition.rst index 031b23c..e07e519 100644 --- a/docs/source/options/condition.rst +++ b/docs/source/options/condition.rst @@ -6,7 +6,7 @@ Condition Overview -------- -Conditions are added using the ``_conditions`` class attribute. They evaluate the combined input data, ensuring that inter-field dependencies and relationships (such as equality, ordering, or presence) meet predefined rules. +Conditions are added using the ``condition()`` declarative. They evaluate the combined input data, ensuring that inter-field dependencies and relationships (such as equality, ordering, or presence) meet predefined rules. Example ------- @@ -24,7 +24,7 @@ Example validators=[IsStringValidator()] ) - _conditions = [OneOfCondition(['id', 'name'])] + condition(OneOfCondition(['id', 'name'])) Available Conditions -------------------- diff --git a/docs/source/options/declarative_api.rst b/docs/source/options/declarative_api.rst new file mode 100644 index 0000000..a75eb22 --- /dev/null +++ b/docs/source/options/declarative_api.rst @@ -0,0 +1,227 @@ +Declarative API +=============== + +Overview +-------- + +The Declarative API is the modern, recommended way to define InputFilters in flask-inputfilter. +It uses Python decorators and class-level declarations to create clean, readable, and maintainable +input validation definitions. + +Key Features +------------ + +- **Clean Syntax**: Define fields, conditions, and global components directly in class definition +- **Type Safety**: Integrates well with type hints and IDEs +- **Inheritance Support**: Full support for class inheritance and MRO +- **Model Integration**: Automatic serialization to dataclasses, Pydantic models, and more +- **Multiple Element Support**: Register multiple components at once for concise definitions + +Core Components +--------------- + +The Declarative API consists of four main decorators: + ++----------------------+--------------------------------------------+ +| Decorator | Purpose | ++======================+============================================+ +| ``field()`` | Define individual input fields | ++----------------------+--------------------------------------------+ +| ``condition()`` | Add cross-field validation conditions | ++----------------------+--------------------------------------------+ +| ``global_filter()`` | Add filters applied to all fields | ++----------------------+--------------------------------------------+ +| ``global_validator()``| Add validators applied to all fields | ++----------------------+--------------------------------------------+ +| ``model()`` | Associate with a model class | ++----------------------+--------------------------------------------+ + +Quick Example +------------- + +.. code-block:: python + + from flask_inputfilter import InputFilter + from flask_inputfilter.declarative import field, condition, global_filter, global_validator, model + from flask_inputfilter.filters import StringTrimFilter, ToLowerFilter + from flask_inputfilter.validators import IsStringValidator, LengthValidator, EmailValidator + from flask_inputfilter.conditions import EqualCondition + from dataclasses import dataclass + + @dataclass + class User: + username: str + email: str + password: str + + class UserRegistrationFilter(InputFilter): + # Field definitions with individual configuration + username = field( + required=True, + validators=[LengthValidator(min_length=3, max_length=20)] + ) + + email = field( + required=True, + validators=[EmailValidator()] + ) + + password = field(required=True, validators=[LengthValidator(min_length=8)]) + password_confirmation = field(required=True) + + # Global components - applied to all fields + global_filter(StringTrimFilter(), ToLowerFilter()) + global_validator(IsStringValidator()) + + # Cross-field validation + condition(EqualCondition('password', 'password_confirmation')) + + # Model association + model(User) + + Declarative API +---------------- + +.. code-block:: python + + class MyFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) + + global_filter(StringTrimFilter()) + condition(RequiredCondition('name')) + +Inheritance and MRO +------------------- + +The Declarative API fully supports Python's inheritance and Method Resolution Order (MRO): + +.. code-block:: python + + class BaseUserFilter(InputFilter): + # Base fields + name = field(required=True, validators=[IsStringValidator()]) + + # Base global components + global_filter(StringTrimFilter()) + + class ExtendedUserFilter(BaseUserFilter): + # Additional fields + email = field(required=True, validators=[EmailValidator()]) + age = field(required=False, validators=[IsIntegerValidator()]) + + # Additional global components (inherited ones are preserved) + global_validator(LengthValidator(min_length=1)) + + # Conditions + condition(RequiredCondition('email')) + +Field Override +~~~~~~~~~~~~~~ + +You can override fields from parent classes: + +.. code-block:: python + + class BaseFilter(InputFilter): + name = field(required=False) # Optional in base + + class StrictFilter(BaseFilter): + name = field(required=True, validators=[LengthValidator(min_length=2)]) # Override + +Multiple Element Registration +----------------------------- + +You can register multiple components at once for cleaner definitions: + +.. code-block:: python + + class CompactFilter(InputFilter): + name = field(required=True) + email = field(required=True) + + # Multiple global filters + global_filter(StringTrimFilter(), ToLowerFilter(), RemoveExtraSpacesFilter()) + + # Multiple global validators + global_validator(IsStringValidator(), NotEmptyValidator()) + + # Multiple conditions + condition( + RequiredCondition('name'), + RequiredCondition('email'), + EqualCondition('password', 'password_confirmation') + ) + +Model Integration +----------------- + +The Declarative API seamlessly integrates with various model types: + +Dataclasses +~~~~~~~~~~~ + +.. code-block:: python + + from dataclasses import dataclass + + @dataclass + class User: + name: str + email: str + + class UserFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) + + model(User) + + # Usage + filter_instance = UserFilter() + user = filter_instance.validate_data({'name': 'John', 'email': 'john@example.com'}) + # user is a User dataclass instance + +Pydantic Models +~~~~~~~~~~~~~~~ + +.. code-block:: python + + from pydantic import BaseModel + + class User(BaseModel): + name: str + email: str + + class UserFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) + + model(User) + +TypedDict +~~~~~~~~~ + +.. code-block:: python + + from typing import TypedDict + + class UserDict(TypedDict): + name: str + email: str + + class UserFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) + + model(UserDict) + +Next Steps +---------- + +For detailed information about each component, see: + +- :doc:`Field Decorator ` - Complete field configuration options +- :doc:`Global Decorators ` - Global filters, validators, and conditions +- :doc:`Conditions ` - Cross-field validation conditions +- :doc:`Filters ` - Available filters and custom filter creation +- :doc:`Validators ` - Available validators and custom validator creation diff --git a/docs/source/options/deserialization.rst b/docs/source/options/deserialization.rst index 3ee072b..e90f9f0 100644 --- a/docs/source/options/deserialization.rst +++ b/docs/source/options/deserialization.rst @@ -11,9 +11,8 @@ Overview The deserialization process is handled through two main methods: -- ``set_model()``: Sets the model class to be used for deserialization -- ``serialize()``: Converts the validated data into an instance of the - specified model class or returns the raw data as a dictionary +- ``model()``, ``set_model()``: Sets the model class to be used for deserialization +- ``validate()``: Validates the input data and deserializes it into an instance of the model class if set Configuration ------------- @@ -36,7 +35,7 @@ into an instance of the model class, if there is a model class set. username: str = field() email: str = field() - _model = User + model(User) Examples -------- @@ -59,7 +58,7 @@ You can also use deserialization in your Flask routes: class MyInputFilter(InputFilter): username: str = field() - _model = User + model(User) app = Flask(__name__) @@ -89,7 +88,7 @@ You can also use deserialization outside of Flask routes: class MyInputFilter(InputFilter): username: str = field() - _model = User + model(User) app = Flask(__name__) @@ -101,4 +100,4 @@ You can also use deserialization outside of Flask routes: if not input_filter.is_valid(): return jsonify({"error": "Invalid data"}), 400 - validated_data: User = input_filter.serialize() + validated_data: User = input_filter.get_values() diff --git a/docs/source/options/field_decorator.rst b/docs/source/options/field_decorator.rst new file mode 100644 index 0000000..f6b479c --- /dev/null +++ b/docs/source/options/field_decorator.rst @@ -0,0 +1,389 @@ +Field Decorator +=============== + +Overview +-------- + +The ``field()`` decorator is the core component for defining individual input fields in the +Declarative API. It provides comprehensive configuration options for field validation, +filtering, and processing. + +.. autofunction:: flask_inputfilter.declarative.field.field + +Basic Usage +----------- + +The simplest field definition: + +.. code-block:: python + + from flask_inputfilter import InputFilter + from flask_inputfilter.declarative import field + + class SimpleFilter(InputFilter): + name = field() # Optional field with no validation + +Required Fields +~~~~~~~~~~~~~~~ + +.. code-block:: python + + class UserFilter(InputFilter): + username = field(required=True) + email = field(required=True) + +Field Configuration Options +--------------------------- + +The ``field()`` decorator accepts the following parameters: + +required +~~~~~~~~ + +**Type**: ``bool`` +**Default**: ``False`` + +Specifies whether the field must be present in the input data. + +.. code-block:: python + + class RegistrationFilter(InputFilter): + username = field(required=True) # Must be provided + nickname = field(required=False) # Optional + +**Examples**: + +.. code-block:: python + + # Required field - will raise ValidationError if missing + email = field(required=True) + + # Optional field - won't raise error if missing + phone = field(required=False) + +default +~~~~~~~ + +**Type**: ``Any`` +**Default**: ``None`` + +Provides a default value when the field is not present in the input data. +Only applies to optional fields (``required=False``). + +.. code-block:: python + + class UserFilter(InputFilter): + name = field(required=True) + role = field(required=False, default="user") + active = field(required=False, default=True) + tags = field(required=False, default=[]) + +**Important**: Use ``default=[]`` carefully with mutable objects. Consider using a factory function: + +.. code-block:: python + + from flask_inputfilter.filters import DefaultValueFilter + + class SafeDefaultFilter(InputFilter): + # Safe for mutable defaults + tags = field(required=False, filters=[DefaultValueFilter(lambda: [])]) + +fallback +~~~~~~~~ + +**Type**: ``Any`` +**Default**: ``None`` + +Specifies a value to use when validation fails or when processing encounters errors. +Unlike ``default``, ``fallback`` applies even when the field is present but invalid. + +.. code-block:: python + + from flask_inputfilter.validators import IsIntegerValidator + + class RobustFilter(InputFilter): + age = field( + required=False, + validators=[IsIntegerValidator()], + fallback=0 # Use 0 if validation fails + ) + + priority = field( + required=True, + validators=[IsIntegerValidator()], + fallback=1 # Use 1 if provided value is invalid + ) + +filters +~~~~~~~ + +**Type**: ``list[BaseFilter]`` +**Default**: ``[]`` + +List of filters to apply to the field value. Filters are applied in the order specified +and transform the input data before validation. + +.. code-block:: python + + from flask_inputfilter.filters import StringTrimFilter, ToLowerFilter, ToIntegerFilter + + class FilteredInputFilter(InputFilter): + username = field( + required=True, + filters=[StringTrimFilter(), ToLowerFilter()] + ) + + age = field( + required=False, + filters=[ToIntegerFilter()] + ) + +**Common Filters**: + +.. code-block:: python + + from flask_inputfilter.filters import ( + StringTrimFilter, # Remove leading/trailing whitespace + ToLowerFilter, # Convert to lowercase + ToUpperFilter, # Convert to uppercase + ToIntegerFilter, # Convert to integer + ToFloatFilter, # Convert to float + ToBooleanFilter, # Convert to boolean + ToNullFilter, # Convert empty strings to None + RemoveFilter, # Remove specific characters + ReplaceFilter, # Replace characters/patterns + ) + + class ComprehensiveFilter(InputFilter): + email = field(filters=[StringTrimFilter(), ToLowerFilter()]) + price = field(filters=[ToFloatFilter()]) + active = field(filters=[ToBooleanFilter()]) + +validators +~~~~~~~~~~ + +**Type**: ``list[BaseValidator]`` +**Default**: ``[]`` + +List of validators to apply to the field value. Validators check the processed value +and raise validation errors if the value doesn't meet the criteria. + +.. code-block:: python + + from flask_inputfilter.validators import ( + IsStringValidator, LengthValidator, EmailValidator, RegexValidator + ) + + class ValidatedFilter(InputFilter): + username = field( + required=True, + validators=[ + IsStringValidator(), + LengthValidator(min_length=3, max_length=20) + ] + ) + + email = field( + required=True, + validators=[IsStringValidator(), EmailValidator()] + ) + +**Common Validators**: + +.. code-block:: python + + from flask_inputfilter.validators import ( + IsStringValidator, # Must be a string + IsIntegerValidator, # Must be an integer + IsFloatValidator, # Must be a float + IsBooleanValidator, # Must be a boolean + IsListValidator, # Must be a list + IsDictValidator, # Must be a dictionary + LengthValidator, # String/list length validation + RangeValidator, # Numeric range validation + EmailValidator, # Email format validation + UrlValidator, # URL format validation + RegexValidator, # Custom regex validation + InValidator, # Value must be in allowed list + NotInValidator, # Value must not be in forbidden list + ) + +**Validator Examples**: + +.. code-block:: python + + class DetailedValidationFilter(InputFilter): + # String validation with length constraints + password = field( + required=True, + validators=[ + IsStringValidator(), + LengthValidator(min_length=8, max_length=128) + ] + ) + + # Numeric validation with range + age = field( + required=True, + validators=[ + IsIntegerValidator(), + RangeValidator(min_value=13, max_value=120) + ] + ) + + # Choice validation + status = field( + required=True, + validators=[ + IsStringValidator(), + InValidator(['active', 'inactive', 'pending']) + ] + ) + + # Custom regex validation + phone = field( + required=False, + validators=[ + IsStringValidator(), + RegexValidator(r'^\+?1?\d{9,15}$', message="Invalid phone number format") + ] + ) + +steps +~~~~~ + +**Type**: ``list`` +**Default**: ``[]`` + +Defines a sequence of processing steps (filters and validators) that are applied in order. +This allows for fine-grained control over the processing pipeline. + +.. code-block:: python + + from flask_inputfilter.filters import StringTrimFilter, ToIntegerFilter + from flask_inputfilter.validators import IsStringValidator, IsIntegerValidator + + class SteppedFilter(InputFilter): + numeric_input = field( + required=True, + steps=[ + StringTrimFilter(), # Step 1: Remove whitespace + IsStringValidator(), # Step 2: Validate it's a string + ToIntegerFilter(), # Step 3: Convert to integer + IsIntegerValidator(), # Step 4: Validate it's an integer + RangeValidator(min_value=0) # Step 5: Validate range + ] + ) + +external_api +~~~~~~~~~~~~ + +**Type**: ``ExternalApiConfig`` +**Default**: ``None`` + +Configuration for fetching field values from external APIs. Useful for data enrichment +or validation against external services. + +.. code-block:: python + + from flask_inputfilter.models import ExternalApiConfig + + class ExternalDataFilter(InputFilter): + user_id = field(required=True, validators=[IsIntegerValidator()]) + + # Fetch user details from external API + user_profile = field( + required=False, + external_api=ExternalApiConfig( + url="https://api.example.com/users/{user_id}", + method="GET", + headers={"Authorization": "Bearer token"} + ) + ) + +For detailed information, see :doc:`ExternalApi `. + +copy +~~~~ + +**Type**: ``str`` +**Default**: ``None`` + +Copy the value from another field. The copied value can then be filtered and validated +independently. + +.. code-block:: python + + class CopyFieldFilter(InputFilter): + email = field(required=True, validators=[EmailValidator()]) + + # Copy email value for confirmation + email_confirmation = field( + required=True, + copy="email", + validators=[EmailValidator()] + ) + + # Processing happens after copying + normalized_email = field( + required=False, + copy="email", + filters=[StringTrimFilter(), ToLowerFilter()] + ) + +For detailed information, see :doc:`Copy `. + +Advanced Field Patterns +----------------------- + +Multi-Type Fields +~~~~~~~~~~~~~~~~~ + +Fields that can accept multiple types: + +.. code-block:: python + + from flask_inputfilter.validators import OrValidator + + class FlexibleTypeFilter(InputFilter): + # Can be either string or integer + identifier = field( + required=True, + validators=[ + OrValidator([ + IsStringValidator(), + IsIntegerValidator() + ]) + ] + ) + +Custom Field Processing +~~~~~~~~~~~~~~~~~~~~~~~ + +Create reusable field configurations: + +.. code-block:: python + + # Define reusable field types + def email_field(required=True): + return field( + required=required, + filters=[StringTrimFilter(), ToLowerFilter()], + validators=[IsStringValidator(), EmailValidator()] + ) + + def username_field(min_length=3, max_length=20): + return field( + required=True, + filters=[StringTrimFilter()], + validators=[ + IsStringValidator(), + LengthValidator(min_length=min_length, max_length=max_length), + RegexValidator(r'^[a-zA-Z0-9_]+$', message="Only letters, numbers, and underscores allowed") + ] + ) + + class UserFilter(InputFilter): + username = username_field() + email = email_field() + backup_email = email_field(required=False) diff --git a/docs/source/options/filter.rst b/docs/source/options/filter.rst index 46df7d0..f05a34e 100644 --- a/docs/source/options/filter.rst +++ b/docs/source/options/filter.rst @@ -6,7 +6,7 @@ Filter Overview -------- -Filters can be added to specific fields using the decorator syntax or as global filters using ``_global_filters``. +Filters can be added to specific fields using the decorator syntax or as global filters using ``global_filter()``. The global filters will be executed before the specific field filtering. @@ -26,7 +26,7 @@ Example filters=[StringTrimFilter()] ) - _global_filters = [ToLowerFilter()] + global_filter(ToLowerFilter()) Available Filters ----------------- diff --git a/docs/source/options/global_decorators.rst b/docs/source/options/global_decorators.rst new file mode 100644 index 0000000..fa53965 --- /dev/null +++ b/docs/source/options/global_decorators.rst @@ -0,0 +1,372 @@ +Global Decorators +================= + +Overview +-------- + +Global decorators apply filters, validators, and conditions to all fields in an InputFilter +or establish cross-field validation rules. They provide a convenient way to implement +common processing logic without repeating configuration for each field. + +The global decorators are: + +- ``global_filter()`` - Apply filters to all fields +- ``global_validator()`` - Apply validators to all fields +- ``condition()`` - Add cross-field validation conditions +- ``model()`` - Associate with a model class for serialization + +All global decorators support multiple element registration for concise definitions. + +global_filter() +--------------- + +Applies filters to all fields in the InputFilter before individual field filters are processed. + +.. autofunction:: flask_inputfilter.declarative.global_filter.global_filter + +Basic Usage +~~~~~~~~~~~ + +.. code-block:: python + + from flask_inputfilter import InputFilter + from flask_inputfilter.declarative import field, global_filter + from flask_inputfilter.filters import StringTrimFilter, ToLowerFilter + + class TrimmedInputFilter(InputFilter): + name = field(required=True) + email = field(required=True) + description = field(required=False) + + # Apply to all fields + global_filter(StringTrimFilter()) + +Multiple Global Filters +~~~~~~~~~~~~~~~~~~~~~~~~ + +Register multiple filters in a single call: + +.. code-block:: python + + class ProcessedInputFilter(InputFilter): + username = field(required=True) + email = field(required=True) + bio = field(required=False) + + # Multiple filters applied to all fields + global_filter(StringTrimFilter(), ToLowerFilter(), RemoveExtraSpacesFilter()) + +**Processing Order**: Global filters are applied first, then individual field filters. + +.. code-block:: python + + class OrderDemoFilter(InputFilter): + name = field( + required=True, + filters=[ToUpperFilter()] # Applied after global filters + ) + + # Applied first to all fields + global_filter(StringTrimFilter()) + + # Input: " john " → StringTrimFilter() → "john" → ToUpperFilter() → "JOHN" + +Inheritance +~~~~~~~~~~~ + +Global filters are inherited and preserved across class hierarchies: + +.. code-block:: python + + class BaseFilter(InputFilter): + global_filter(StringTrimFilter()) + + class ExtendedFilter(BaseFilter): + name = field(required=True) + + # Additional global filter (StringTrimFilter is preserved) + global_filter(ToLowerFilter()) + + # ExtendedFilter has both StringTrimFilter and ToLowerFilter + +global_validator() +------------------ + +Applies validators to all fields in the InputFilter after filters have been processed. + +.. autofunction:: flask_inputfilter.declarative.global_validator.global_validator + +Basic Usage +~~~~~~~~~~~ + +.. code-block:: python + + from flask_inputfilter.validators import IsStringValidator, NotEmptyValidator + + class ValidatedInputFilter(InputFilter): + name = field(required=True) + email = field(required=True) + description = field(required=False) + + # All fields must be strings and not empty + global_validator(IsStringValidator()) + +Multiple Global Validators +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + class StrictInputFilter(InputFilter): + username = field(required=True) + email = field(required=True) + password = field(required=True) + + # Multiple validators applied to all fields + global_validator( + IsStringValidator(), + NotEmptyValidator(), + LengthValidator(min_length=1, max_length=255) + ) + +Inheritance and Override +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + class BaseValidatedFilter(InputFilter): + global_validator(IsStringValidator()) + + class StrictFilter(BaseValidatedFilter): + name = field(required=True) + + # Additional validators (IsStringValidator is preserved) + global_validator(LengthValidator(min_length=2)) + +Practical Examples +~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Ensure all text fields are valid strings + class TextOnlyFilter(InputFilter): + global_validator(IsStringValidator(), NotEmptyValidator()) + + # Length constraints for all fields + class BoundedFilter(InputFilter): + global_validator(LengthValidator(max_length=1000)) + + # Security validation + class SecureInputFilter(InputFilter): + global_validator( + NoScriptTagValidator(), + NoSqlInjectionValidator(), + NoXssValidator() + ) + +condition() +----------- + +Adds cross-field validation conditions that validate relationships between multiple fields. + +.. autofunction:: flask_inputfilter.declarative.condition.condition + +Basic Usage +~~~~~~~~~~~ + +.. code-block:: python + + from flask_inputfilter.conditions import EqualCondition + + class RegistrationFilter(InputFilter): + password = field(required=True) + password_confirmation = field(required=True) + + # Password confirmation validation + condition(EqualCondition('password', 'password_confirmation')) + +Multiple Conditions +~~~~~~~~~~~~~~~~~~~ + +Register multiple conditions in a single call: + +.. code-block:: python + + from flask_inputfilter.conditions import ( + EqualCondition, AtLeastOneOfCondition, ExactlyOneOfCondition + ) + + class ComplexValidationFilter(InputFilter): + password = field(required=True) + password_confirmation = field(required=True) + email = field(required=False) + phone = field(required=False) + address = field(required=False) + + # Multiple cross-field validations + condition( + EqualCondition('password', 'password_confirmation'), + AtLeastOneOfCondition(['email', 'phone']), # Need at least one contact method + ExactlyOneOfCondition(['email', 'phone', 'address']) # But only one primary contact + ) + +model() +------- + +Associates the InputFilter with a model class for automatic serialization. + +.. autofunction:: flask_inputfilter.declarative.model.model + +Basic Usage +~~~~~~~~~~~ + +.. code-block:: python + + from dataclasses import dataclass + from flask_inputfilter.declarative import model + + @dataclass + class User: + username: str + email: str + + class UserFilter(InputFilter): + username = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) + + model(User) + + # Usage + filter_instance = UserFilter() + user = filter_instance.validate_data({'username': 'john', 'email': 'john@example.com'}) + # user is a User dataclass instance + +Supported Model Types +~~~~~~~~~~~~~~~~~~~~~ + +**Dataclasses**: + +.. code-block:: python + + from dataclasses import dataclass + + @dataclass + class Product: + name: str + price: float + in_stock: bool = True + + class ProductFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + price = field(required=True, filters=[ToFloatFilter()]) + in_stock = field(required=False, filters=[ToBooleanFilter()], default=True) + + model(Product) + +**Pydantic Models**: + +.. code-block:: python + + from pydantic import BaseModel + + class User(BaseModel): + username: str + email: str + age: int = None + + class UserFilter(InputFilter): + username = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) + age = field(required=False, filters=[ToIntegerFilter()]) + + model(User) + +**TypedDict**: + +.. code-block:: python + + from typing import TypedDict + + class UserDict(TypedDict): + username: str + email: str + + class UserFilter(InputFilter): + username = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) + + model(UserDict) + +Field Filtering +~~~~~~~~~~~~~~~ + +The model decorator automatically filters out fields that don't exist in the model: + +.. code-block:: python + + @dataclass + class SimpleUser: + name: str # Only has name field + + class ExtendedUserFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[EmailValidator()]) # Extra field + age = field(required=False, filters=[ToIntegerFilter()]) # Extra field + + model(SimpleUser) + + # Only 'name' will be passed to SimpleUser constructor + # 'email' and 'age' are filtered out automatically + +Inheritance and Advanced Usage +------------------------------ + +Combining Global Decorators +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + class FullFeaturedFilter(InputFilter): + username = field(required=True) + email = field(required=True) + password = field(required=True) + password_confirmation = field(required=True) + phone = field(required=False) + + # Global processing + global_filter(StringTrimFilter(), ToLowerFilter()) + global_validator(IsStringValidator(), NotEmptyValidator()) + + # Cross-field validation + condition( + EqualCondition('password', 'password_confirmation'), + AtLeastOneOfCondition(['email', 'phone']) + ) + + # Model association + model(User) + +Hierarchical Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + class BaseUserFilter(InputFilter): + # Base global configuration + global_filter(StringTrimFilter()) + global_validator(IsStringValidator()) + + class StandardUserFilter(BaseUserFilter): + username = field(required=True) + email = field(required=True) + + # Additional processing + global_validator(NotEmptyValidator()) + + class AdminUserFilter(StandardUserFilter): + role = field(required=True, default="admin") + permissions = field(required=False, default=[]) + + # Admin-specific validation + global_validator(SecurityValidator()) + condition(AdminPermissionCondition()) + # errors will contain both field-level and condition-level errors diff --git a/docs/source/options/index.rst b/docs/source/options/index.rst index 00b799c..f0e2e7b 100644 --- a/docs/source/options/index.rst +++ b/docs/source/options/index.rst @@ -5,6 +5,9 @@ Options :maxdepth: 2 inputfilter + declarative_api + field_decorator + global_decorators validator special_validator filter diff --git a/docs/source/options/inputfilter.rst b/docs/source/options/inputfilter.rst index 90fb6f1..dda7eca 100644 --- a/docs/source/options/inputfilter.rst +++ b/docs/source/options/inputfilter.rst @@ -4,94 +4,235 @@ InputFilter Overview -------- +The ``InputFilter`` class is the core component of flask-inputfilter that provides data validation, +filtering, and serialization capabilities for Flask applications. It supports two different approaches +for defining input filters. + .. autoclass:: flask_inputfilter.input_filter.InputFilter :members: :undoc-members: :show-inheritance: -Configuration -------------- +Defining InputFilters +--------------------- + +The declarative API uses Python decorators to define fields, conditions, and global filters/validators +directly in the class definition. + +.. code-block:: python + + from flask_inputfilter import InputFilter + from flask_inputfilter.declarative import field, condition, global_filter, global_validator + from flask_inputfilter.filters import StringTrimFilter + from flask_inputfilter.validators import IsStringValidator, LengthValidator + from flask_inputfilter.conditions import EqualCondition + + class UserRegistrationFilter(InputFilter): + # Field definitions + username = field( + required=True, + filters=[StringTrimFilter()], + validators=[IsStringValidator(), LengthValidator(min_length=3, max_length=20)] + ) + + email = field( + required=True, + validators=[IsStringValidator()] + ) + + password = field(required=True) + password_confirmation = field(required=True) + + # Global filters applied to all fields + global_filter(StringTrimFilter()) + + # Global validators applied to all fields + global_validator(IsStringValidator()) + + # Conditions for cross-field validation + condition(EqualCondition('password', 'password_confirmation')) + +For detailed information about the declarative API, see: + +- :doc:`Declarative API Overview ` +- :doc:`Field Decorator ` +- :doc:`Global Decorators ` + +Usage +----- + +Using the Validate Decorator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The most common way to use InputFilter is as a decorator on Flask routes: + +.. code-block:: python + + from flask import Flask, g + + app = Flask(__name__) + + @app.route('/register', methods=['POST']) + @UserRegistrationFilter.validate() + def register(): + # Access validated data + data = g.validated_data + + username = data['username'] + email = data['email'] -The ``add`` method supports several options: + # Your registration logic here + return {'message': 'User registered successfully'} -- `Name`_ -- `Required`_ -- `Filters`_ -- `Validators`_ -- `Default`_ -- `Fallback`_ -- `Steps`_ -- `ExternalApi`_ -- `Copy`_ +Manual Validation +~~~~~~~~~~~~~~~~~ -Name -~~~~ +You can also validate data manually: -The ``name`` option specifies the name of the field. -This is the key that will be used to access the field value in the validated data. +.. code-block:: python -Required -~~~~~~~~ + filter_instance = UserRegistrationFilter() -The ``required`` option specifies whether the field must be included in the input data. -If the field is missing, a ``ValidationError`` will be raised with an appropriate error message. + filter_instance.set_data({ + 'username': 'john_doe', + 'email': 'john@example.com', + 'password': 'secret123', + 'password_confirmation': 'secret123' + }) -Filters -~~~~~~~ + if not filter_instance.is_valid(): + print("Wrong credentials") + return -The ``filters`` option allows you to specify one or more filters to apply to the field value. -Filters are applied in the order they are defined. + print(f"Welcome {filter_instance.get_value('username')}!") -For more information view the :doc:`Filter ` documentation. +Model Serialization +~~~~~~~~~~~~~~~~~~~ -Validators +InputFilter can automatically serialize validated data to model instances: + +.. code-block:: python + + from dataclasses import dataclass + from flask_inputfilter.declarative import model + + @dataclass + class User: + username: str + email: str + + class UserFilter(InputFilter): + username = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[IsStringValidator()]) + + # Associate with model + model(User) + + filter_instance = UserFilter() + filter_instance.set_data({ + 'username': 'john_doe', + 'email': 'john@example.com' + }) + + if filter_instance.is_valid(): + # Access validated data and create model instance + user_instance = User( + username=filter_instance.get_value('username'), + email=filter_instance.get_value('email') + ) + + # user_instance is now a User object + assert isinstance(user_instance, User) + assert user_instance.username == 'john_doe' + +Key Methods +----------- + +validate() ~~~~~~~~~~ -The ``validators`` option allows you to specify one or more validators to apply to the field value. -Validators are applied in the order they are defined. +Class method that returns a decorator for Flask routes. Automatically validates request data +and stores the result in ``g.validated_data``. -For more information view the :doc:`Validator ` documentation. +.. code-block:: python -Default -~~~~~~~ + @app.route('/api/users', methods=['POST']) + @UserFilter.validate() + def create_user(): + data = g.validated_data + # Use validated data -The ``default`` option allows you to specify a default value to use if the field is not -present in the input data. +set_data(data) +~~~~~~~~~~~~~~ -Fallback -~~~~~~~~ +Instance method that sets the input data to be validated. -The ``fallback`` option specifies a value to use if validation fails or required data -is missing. Note that if the field is optional and absent, ``fallback`` will not apply; -use ``default`` in such cases. +.. code-block:: python -Steps -~~~~~ + filter_instance = UserFilter() + filter_instance.set_data({'username': 'john'}) -The ``steps`` option allows you to specify a list of different filters and validator to apply to the field value. -It respects the order of the list. +is_valid() +~~~~~~~~~~ -ExternalApi -~~~~~~~~~~~ +Instance method that validates the current data and returns ``True`` if valid, ``False`` otherwise. -The ``external_api`` option allows you to specify an external API to call for the field value. -The API call is made when the field is validated, and the response is used as the field value. +.. code-block:: python -For more information view the :doc:`ExternalApi ` documentation. + if filter_instance.is_valid(): + # Data is valid, proceed with processing + username = filter_instance.get_value('username') -Copy -~~~~ +get_value(field_name) +~~~~~~~~~~~~~~~~~~~~~ -The ``copy`` option allows you to copy the value of another field. -The copied value can be filtered and validated, due to the coping being executed first. +Instance method that returns the validated value for a specific field. -For more information view the :doc:`Copy ` documentation. +.. code-block:: python + username = filter_instance.get_value('username') + email = filter_instance.get_value('email') -Examples --------- +Error Handling +-------------- + +Flask-inputfilter provides error information when validation fails. You can access validation errors +through the filter instance: + +.. code-block:: python + + filter_instance = UserFilter() + filter_instance.set_data({'username': ''}) # Empty username + + if not filter_instance.is_valid(): + errors = filter_instance.get_errors() + # errors = {'username': 'This field is required'} + +When using the decorator, validation errors automatically return a 400 response with +the error details in JSON format. + +Best Practices +-------------- + +1. **Organize Complex Filters**: Break down complex filters into base classes using inheritance +2. **Model Integration**: Use model serialization for type-safe data handling +3. **Global Components**: Use global filters/validators for common processing (e.g., trimming strings) -Least config +.. code-block:: python + class BaseFilter(InputFilter): + # Common global filters and validators + global_filter(StringTrimFilter()) + global_validator(IsStringValidator()) -Full config + class UserFilter(BaseFilter): + username = field( + required=True, + validators=[ + LengthValidator( + min_length=3, + max_length=20, + message="Username must be between 3 and 20 characters" + ) + ] + ) diff --git a/docs/source/options/validator.rst b/docs/source/options/validator.rst index d1cc3c7..adca494 100644 --- a/docs/source/options/validator.rst +++ b/docs/source/options/validator.rst @@ -7,7 +7,7 @@ They ensure that the input data meets the required conditions before further pro Overview -------- -Validators can be added to specific fields using the decorator syntax or as global validators using ``_global_validators``. +Validators can be added to specific fields using the decorator syntax or as global validators using ``global_validator()``. The global validation will be executed before the specific field validation. @@ -24,7 +24,7 @@ Example validators=[RangeValidator(min_value=0, max_value=10)] ) - _global_validators = [IsIntegerValidator()] + global_validator(IsIntegerValidator()) Available Validators -------------------- diff --git a/examples/basic/filters/user_inputfilter.py b/examples/basic/filters/user_inputfilter.py index b5f9484..9be83aa 100644 --- a/examples/basic/filters/user_inputfilter.py +++ b/examples/basic/filters/user_inputfilter.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from flask_inputfilter import InputFilter -from flask_inputfilter.declarative import field +from flask_inputfilter.declarative import field, model from flask_inputfilter.validators import IsIntegerValidator, IsStringValidator @@ -13,7 +13,7 @@ class User: class UserInputFilter(InputFilter): - _model = User + model(User) name: str = field(required=True, validators=[IsStringValidator()]) diff --git a/flask_inputfilter/_input_filter.pxd b/flask_inputfilter/_input_filter.pxd index a8d3ed9..ac68ba5 100644 --- a/flask_inputfilter/_input_filter.pxd +++ b/flask_inputfilter/_input_filter.pxd @@ -1,4 +1,5 @@ # cython: language=c++ + from typing import Any from flask_inputfilter.models.cimports cimport BaseCondition, BaseFilter, BaseValidator, ExternalApiConfig, FieldModel @@ -71,8 +72,8 @@ cdef class InputFilter: cpdef void clear(self) cpdef void merge(self, InputFilter other) cpdef void set_model(self, object model_class) - cpdef object serialize(self) cpdef void add_global_validator(self, BaseValidator validator) cpdef list[BaseValidator] get_global_validators(self) + cdef object _serialize(self) cdef void _set_methods(self, list methods) cdef void _register_decorator_components(self) diff --git a/flask_inputfilter/_input_filter.pyx b/flask_inputfilter/_input_filter.pyx index 8c97cde..f17fa80 100644 --- a/flask_inputfilter/_input_filter.pyx +++ b/flask_inputfilter/_input_filter.pyx @@ -9,6 +9,8 @@ # cython: optimize.unpack_method_calls=True # cython: infer_types=True +import dataclasses +import inspect import json import logging import sys @@ -228,7 +230,7 @@ cdef class InputFilter: raise ValidationError(errors) self.validated_data = validated_data - return self.serialize() + return self._serialize() cpdef void add_condition(self, BaseCondition condition): """ @@ -240,44 +242,64 @@ cdef class InputFilter: self.conditions.append(condition) cdef void _register_decorator_components(self): - """Register decorator-based components from the current class only.""" - cdef object cls, attr_value, conditions, validators, filters + """Register decorator-based components from the current class and + inheritance chain.""" + cdef object cls, attr_value, conditions, validators, filters, base_cls cdef str attr_name cdef list dir_attrs cdef FieldDescriptor field_desc + cdef set added_conditions, added_global_validators, added_global_filters + cdef object condition_id, validator_id, filter_id + cdef object condition, validator, filter_instance cls = self.__class__ dir_attrs = dir(cls) for attr_name in dir_attrs: - if (attr_name.encode('utf-8')).startswith(b"_"): + if attr_name.startswith("_"): continue + if hasattr(cls, attr_name): + attr_value = getattr(cls, attr_name) + if isinstance(attr_value, FieldDescriptor): + self.fields[attr_name] = FieldModel( + attr_value.required, + attr_value.default, + attr_value.fallback, + attr_value.filters, + attr_value.validators, + attr_value.steps, + attr_value.external_api, + attr_value.copy, + ) - attr_value = getattr(cls, attr_name, None) - if attr_value is not None and isinstance(attr_value, FieldDescriptor): - field_desc = attr_value - self.fields[attr_name] = FieldModel( - field_desc.required, - field_desc._default, - field_desc.fallback, - field_desc.filters, - field_desc.validators, - field_desc.steps, - field_desc.external_api, - field_desc.copy, - ) - - conditions = getattr(cls, "_conditions", None) - if conditions is not None: - self.conditions.extend(conditions) - - validators = getattr(cls, "_global_validators", None) - if validators is not None: - self.global_validators.extend(validators) - - filters = getattr(cls, "_global_filters", None) - if filters is not None: - self.global_filters.extend(filters) + added_conditions = set() + added_global_validators = set() + added_global_filters = set() + + for base_cls in reversed(cls.__mro__): + conditions = getattr(base_cls, "_conditions", None) + if conditions is not None: + for condition in conditions: + condition_id = id(condition) + if condition_id not in added_conditions: + self.conditions.append(condition) + added_conditions.add(condition_id) + + validators = getattr(base_cls, "_global_validators", None) + if validators is not None: + for validator in validators: + validator_id = id(validator) + if validator_id not in added_global_validators: + self.global_validators.append(validator) + added_global_validators.add(validator_id) + + filters = getattr(base_cls, "_global_filters", None) + if filters is not None: + for filter_instance in filters: + filter_id = id(filter_instance) + if filter_id not in added_global_filters: + self.global_filters.append(filter_instance) + added_global_filters.add(filter_id) self.model_class = getattr(cls, "_model", self.model_class) @@ -719,7 +741,7 @@ cdef class InputFilter: """ self.model_class = model_class - cpdef object serialize(self): + cdef object _serialize(self): """ Serialize the validated data. If a model class is set, returns an instance of that class, otherwise returns the @@ -731,7 +753,26 @@ cdef class InputFilter: if self.model_class is None: return self.validated_data - return self.model_class(**self.validated_data) + try: + return self.model_class(**self.validated_data) + except TypeError: + pass + + cdef set field_names = set() + + if dataclasses.is_dataclass(self.model_class): + field_names = {f.name for f in dataclasses.fields(self.model_class)} + elif hasattr(self.model_class, '__fields__'): + field_names = set(self.model_class.__fields__.keys()) + elif hasattr(self.model_class, '__annotations__'): + field_names = set(self.model_class.__annotations__.keys()) + else: + sig = inspect.signature(self.model_class.__init__) + field_names = set(sig.parameters.keys()) - {'self'} + + cdef dict filtered_data = {k: v for k, v in self.validated_data.items() if k in field_names} + + return self.model_class(**filtered_data) cpdef void add_global_validator(self, BaseValidator validator): """ diff --git a/flask_inputfilter/conditions/array_length_equal_condition.py b/flask_inputfilter/conditions/array_length_equal_condition.py index c7f1bf0..7ecd9dd 100644 --- a/flask_inputfilter/conditions/array_length_equal_condition.py +++ b/flask_inputfilter/conditions/array_length_equal_condition.py @@ -28,12 +28,12 @@ class ArrayLengthFilter(InputFilter): list1: list = field(validators=[IsArrayValidator()]) list2: list = field(validators=[IsArrayValidator()]) - _conditions = [ + condition( ArrayLengthEqualCondition( first_array_field='list1', second_array_field='list2' ) - ] + ) """ __slots__ = ("first_array_field", "second_array_field") diff --git a/flask_inputfilter/conditions/base_condition.py b/flask_inputfilter/conditions/base_condition.py index c6ad10e..4a4ba7f 100644 --- a/flask_inputfilter/conditions/base_condition.py +++ b/flask_inputfilter/conditions/base_condition.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings warnings.warn( @@ -7,5 +9,3 @@ DeprecationWarning, stacklevel=2, ) - -from flask_inputfilter.models import BaseCondition diff --git a/flask_inputfilter/conditions/custom_condition.py b/flask_inputfilter/conditions/custom_condition.py index 28145aa..61e93c6 100644 --- a/flask_inputfilter/conditions/custom_condition.py +++ b/flask_inputfilter/conditions/custom_condition.py @@ -33,11 +33,11 @@ def my_custom_condition(data): class CustomFilter(InputFilter): age: int = field(validators=[IsIntegerValidator()]) - _conditions = [ + _condition( CustomCondition( condition=my_custom_condition ) - ] + ) """ __slots__ = ("condition",) diff --git a/flask_inputfilter/conditions/exactly_one_of_condition.py b/flask_inputfilter/conditions/exactly_one_of_condition.py index 716c01c..c577180 100644 --- a/flask_inputfilter/conditions/exactly_one_of_condition.py +++ b/flask_inputfilter/conditions/exactly_one_of_condition.py @@ -26,7 +26,7 @@ class OneFieldFilter(InputFilter): email: str = field() phone: str = field() - _conditions = [ExactlyOneOfCondition(['email', 'phone'])] + condition(ExactlyOneOfCondition(['email', 'phone'])) """ __slots__ = ("fields",) diff --git a/flask_inputfilter/conditions/integer_bigger_than_condition.py b/flask_inputfilter/conditions/integer_bigger_than_condition.py index 01609d9..f1897e4 100644 --- a/flask_inputfilter/conditions/integer_bigger_than_condition.py +++ b/flask_inputfilter/conditions/integer_bigger_than_condition.py @@ -31,12 +31,12 @@ class NumberComparisonFilter(InputFilter): validators=[IsIntegerValidator()] ) - _conditions = [ + condition( IntegerBiggerThanCondition( bigger_field='field_should_be_bigger', smaller_field='field_should_be_smaller' ) - ] + ) """ __slots__ = ("bigger_field", "smaller_field") diff --git a/flask_inputfilter/conditions/not_equal_condition.py b/flask_inputfilter/conditions/not_equal_condition.py index 4f2ebcc..cf9e3c8 100644 --- a/flask_inputfilter/conditions/not_equal_condition.py +++ b/flask_inputfilter/conditions/not_equal_condition.py @@ -27,7 +27,7 @@ class DifferenceFilter(InputFilter): field1: str = field() field2: str = field() - _conditions = [NotEqualCondition('field1', 'field2')] + condition(NotEqualCondition('field1', 'field2')) """ __slots__ = ("first_field", "second_field") diff --git a/flask_inputfilter/conditions/one_of_condition.py b/flask_inputfilter/conditions/one_of_condition.py index f56ff88..27a7032 100644 --- a/flask_inputfilter/conditions/one_of_condition.py +++ b/flask_inputfilter/conditions/one_of_condition.py @@ -27,11 +27,11 @@ class OneFieldRequiredFilter(InputFilter): email: str = field() phone: str = field() - _conditions = [ + condition( OneOfCondition( fields=['email', 'phone'] ) - ] + ) """ __slots__ = ("fields",) diff --git a/flask_inputfilter/conditions/string_longer_than_condition.py b/flask_inputfilter/conditions/string_longer_than_condition.py index 3b8fcb6..09e692f 100644 --- a/flask_inputfilter/conditions/string_longer_than_condition.py +++ b/flask_inputfilter/conditions/string_longer_than_condition.py @@ -26,12 +26,12 @@ class StringLengthFilter(InputFilter): description: str = field() summary: str = field() - _conditions = [ + condition( StringLongerThanCondition( longer_field='description', shorter_field='summary' ) - ] + ) """ __slots__ = ("longer_field", "shorter_field") diff --git a/flask_inputfilter/declarative/__init__.py b/flask_inputfilter/declarative/__init__.py index e5785bc..0ebe4bd 100644 --- a/flask_inputfilter/declarative/__init__.py +++ b/flask_inputfilter/declarative/__init__.py @@ -1,11 +1,19 @@ try: - from ._factory_functions import field from ._field_descriptor import FieldDescriptor except ImportError: - from .factory_functions import field from .field_descriptor import FieldDescriptor +from .condition import condition +from .field import field +from .global_filter import global_filter +from .global_validator import global_validator +from .model import model + __all__ = [ "FieldDescriptor", + "condition", "field", + "global_filter", + "global_validator", + "model", ] diff --git a/flask_inputfilter/declarative/_factory_functions.pxd b/flask_inputfilter/declarative/_factory_functions.pxd deleted file mode 100644 index f89c964..0000000 --- a/flask_inputfilter/declarative/_factory_functions.pxd +++ /dev/null @@ -1,15 +0,0 @@ -# cython: language_level=3 - -from flask_inputfilter.models.cimports cimport BaseFilter, BaseValidator, ExternalApiConfig -from ._field_descriptor cimport FieldDescriptor - -cpdef FieldDescriptor field( - bint required=*, - object default=*, - object fallback=*, - list[BaseFilter] filters=*, - list[BaseValidator] validators=*, - list steps=*, - ExternalApiConfig external_api=*, - str copy=* -) \ No newline at end of file diff --git a/flask_inputfilter/declarative/_factory_functions.pyx b/flask_inputfilter/declarative/_factory_functions.pyx deleted file mode 100644 index ec05665..0000000 --- a/flask_inputfilter/declarative/_factory_functions.pyx +++ /dev/null @@ -1,64 +0,0 @@ -# cython: language_level=3 - -from ._field_descriptor cimport FieldDescriptor -from flask_inputfilter.models.cimports cimport BaseFilter, BaseValidator, ExternalApiConfig - -cpdef FieldDescriptor field( - bint required = False, - object default = None, - object fallback = None, - list[BaseFilter] filters = None, - list[BaseValidator] validators = None, - list steps = None, - ExternalApiConfig external_api = None, - str copy = None, -): - """ - Create a field descriptor for declarative field definition. - - This function creates a FieldDescriptor that can be used as a class - attribute to define input filter fields declaratively. - - **Parameters:** - - - **required** (*bool*): Whether the field is required. Default: False. - - **default** (*Any*): The default value of the field. Default: None. - - **fallback** (*Any*): The fallback value of the field, if - validations fail or field is None, although it is required. Default: None. - - **filters** (*Optional[list[BaseFilter]]*): The filters to apply to - the field value. Default: None. - - **validators** (*Optional[list[BaseValidator]]*): The validators to - apply to the field value. Default: None. - - **steps** (*Optional[list[Union[BaseFilter, BaseValidator]]]*): Allows - to apply multiple filters and validators in a specific order. Default: None. - - **external_api** (*Optional[ExternalApiConfig]*): Configuration for an - external API call. Default: None. - - **copy** (*Optional[str]*): The name of the field to copy the value - from. Default: None. - - **Returns:** - - A field descriptor configured with the given parameters. - - **Example:** - - .. code-block:: python - - from flask_inputfilter import InputFilter - from flask_inputfilter.declarative import field - from flask_inputfilter.validators import IsStringValidator - - class UserInputFilter(InputFilter): - name: str = field(required=True, validators=[IsStringValidator()]) - age: int = field(required=True, default=18) - """ - return FieldDescriptor( - required=required, - default=default, - fallback=fallback, - filters=filters, - validators=validators, - steps=steps, - external_api=external_api, - copy=copy, - ) diff --git a/flask_inputfilter/declarative/_utils.py b/flask_inputfilter/declarative/_utils.py new file mode 100644 index 0000000..790dd31 --- /dev/null +++ b/flask_inputfilter/declarative/_utils.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import inspect +from typing import Any + + +def register_class_attribute(attribute_name: str, value: Any) -> None: + """ + Register an attribute on the calling class during class definition. + + This utility function uses frame inspection to access the class being + defined and set an attribute on it. This is used by the declarative + factory functions to register components during class definition. + + Args: + attribute_name: The name of the attribute to set on the class + value: The value to set for the attribute + """ + frame = inspect.currentframe() + if frame and frame.f_back and frame.f_back.f_back: + caller_locals = frame.f_back.f_back.f_locals + if "__module__" in caller_locals and "__qualname__" in caller_locals: + caller_locals[attribute_name] = value + + +def append_to_class_list(list_name: str, value: Any) -> None: + """ + Append a value to a list attribute on the calling class during class + definition. + + This utility function uses frame inspection to access the class being + defined and append a value to a list attribute. If the list doesn't exist, + it creates it first. + + Args: + list_name: The name of the list attribute on the class + value: The value to append to the list + """ + frame = inspect.currentframe() + if frame and frame.f_back and frame.f_back.f_back: + caller_locals = frame.f_back.f_back.f_locals + if "__module__" in caller_locals and "__qualname__" in caller_locals: + if list_name not in caller_locals: + caller_locals[list_name] = [] + caller_locals[list_name].append(value) diff --git a/flask_inputfilter/declarative/cimports.pxd b/flask_inputfilter/declarative/cimports.pxd index 7cd8b1a..2c5fb38 100644 --- a/flask_inputfilter/declarative/cimports.pxd +++ b/flask_inputfilter/declarative/cimports.pxd @@ -1,2 +1 @@ -from ._factory_functions cimport field from ._field_descriptor cimport FieldDescriptor diff --git a/flask_inputfilter/declarative/condition.py b/flask_inputfilter/declarative/condition.py new file mode 100644 index 0000000..1c73961 --- /dev/null +++ b/flask_inputfilter/declarative/condition.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ._utils import append_to_class_list + +if TYPE_CHECKING: + from flask_inputfilter.models import BaseCondition + + +def condition(*condition_instances: BaseCondition) -> None: + """ + Register one or more conditions for declarative condition definition. + + This function registers conditions directly in the class definition + without requiring variable assignment or __init__ methods. + + **Parameters:** + + - **condition_instances** (*BaseCondition*): One or more condition + instances to register. + + **Examples:** + + .. code-block:: python + + class RegistrationInputFilter(InputFilter): + password: str = field( + required=True, validators=[IsStringValidator()] + ) + password_confirmation: str = field( + required=True, validators=[IsStringValidator()] + ) + + # Single condition + condition(EqualCondition('password', 'password_confirmation')) + + # Multiple conditions at once + condition( + RequiredCondition('password'), + LengthCondition('password', min_length=8) + ) + """ + for condition_instance in condition_instances: + append_to_class_list("_conditions", condition_instance) diff --git a/flask_inputfilter/declarative/factory_functions.pyi b/flask_inputfilter/declarative/factory_functions.pyi deleted file mode 100644 index 12f15de..0000000 --- a/flask_inputfilter/declarative/factory_functions.pyi +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Optional, Union - -if TYPE_CHECKING: - from flask_inputfilter.models import ( - BaseFilter, - BaseValidator, - ExternalApiConfig, - ) - -from .field_descriptor import FieldDescriptor - -def field( - required: bool = False, - default: Any = None, - fallback: Any = None, - filters: Optional[list[BaseFilter]] = None, - validators: Optional[list[BaseValidator]] = None, - steps: Optional[list[Union[BaseFilter, BaseValidator]]] = None, - external_api: Optional[ExternalApiConfig] = None, - copy: Optional[str] = None, -) -> FieldDescriptor: - """ - Create a field descriptor for declarative field definition. - - This function creates a FieldDescriptor that can be used as a class - attribute to define input filter fields declaratively. - - **Parameters:** - - - **required** (*bool*): Whether the field is required. Default: False. - - **default** (*Any*): The default value of the field. Default: None. - - **fallback** (*Any*): The fallback value of the field, if - validations fail or field is None, although it is required. - Default: None. - - **filters** (*Optional[list[BaseFilter]]*): The filters to apply to - the field value. Default: None. - - **validators** (*Optional[list[BaseValidator]]*): The validators to - apply to the field value. Default: None. - - **steps** (*Optional[list[Union[BaseFilter, BaseValidator]]]*): Allows - to apply multiple filters and validators in a specific order. - Default: None. - - **external_api** (*Optional[ExternalApiConfig]*): Configuration for an - external API call. Default: None. - - **copy** (*Optional[str]*): The name of the field to copy the value - from. Default: None. - - **Returns:** - - A field descriptor configured with the given parameters. - - **Example:** - - .. code-block:: python - - from flask_inputfilter import InputFilter - from flask_inputfilter.declarative import field - from flask_inputfilter.validators import IsStringValidator - - class UserInputFilter(InputFilter): - name: str = field(required=True, validators=[IsStringValidator()]) - age: int = field(required=True, default=18) - """ diff --git a/flask_inputfilter/declarative/factory_functions.py b/flask_inputfilter/declarative/field.py similarity index 97% rename from flask_inputfilter/declarative/factory_functions.py rename to flask_inputfilter/declarative/field.py index bd5b40e..a4c086b 100644 --- a/flask_inputfilter/declarative/factory_functions.py +++ b/flask_inputfilter/declarative/field.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union -from .field_descriptor import FieldDescriptor +from flask_inputfilter.declarative import FieldDescriptor if TYPE_CHECKING: from flask_inputfilter.models import ( @@ -13,6 +13,7 @@ def field( + *, required: bool = False, default: Any = None, fallback: Any = None, diff --git a/flask_inputfilter/declarative/global_filter.py b/flask_inputfilter/declarative/global_filter.py new file mode 100644 index 0000000..6edc485 --- /dev/null +++ b/flask_inputfilter/declarative/global_filter.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ._utils import append_to_class_list + +if TYPE_CHECKING: + from flask_inputfilter.models import BaseFilter + + +def global_filter(*filter_instances: BaseFilter) -> None: + """ + Register one or more global filters for declarative definition. + + This function registers global filters directly in the class definition + without requiring variable assignment or __init__ methods. + + **Parameters:** + + - **filter_instances** (*BaseFilter*): One or more filter instances to + register globally. + + **Examples:** + + .. code-block:: python + + class MyInputFilter(InputFilter): + name: str = field(required=True, validators=[IsStringValidator()]) + email: str = field(required=True) + + # Single global filter + global_filter(StringTrimFilter()) + + # Multiple global filters at once + global_filter( + StringTrimFilter(), + ToLowerFilter(), + RemoveWhitespaceFilter() + ) + """ + for filter_instance in filter_instances: + append_to_class_list("_global_filters", filter_instance) diff --git a/flask_inputfilter/declarative/global_validator.py b/flask_inputfilter/declarative/global_validator.py new file mode 100644 index 0000000..b3e367a --- /dev/null +++ b/flask_inputfilter/declarative/global_validator.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ._utils import append_to_class_list + +if TYPE_CHECKING: + from flask_inputfilter.models import BaseValidator + + +def global_validator(*validator_instances: BaseValidator) -> None: + """ + Register one or more global validators for declarative definition. + + This function registers global validators directly in the class definition + without requiring variable assignment or __init__ methods. + + **Parameters:** + + - **validator_instances** (*BaseValidator*): One or more validator + instances to register globally. + + **Examples:** + + .. code-block:: python + + class MyInputFilter(InputFilter): + name: str = field(required=True) + email: str = field(required=True) + + # Single global validator + global_validator(IsStringValidator()) + + # Multiple global validators at once + global_validator( + IsStringValidator(), + LengthValidator(min_length=1), + NotEmptyValidator() + ) + """ + for validator_instance in validator_instances: + append_to_class_list("_global_validators", validator_instance) diff --git a/flask_inputfilter/declarative/model.py b/flask_inputfilter/declarative/model.py new file mode 100644 index 0000000..7018ce0 --- /dev/null +++ b/flask_inputfilter/declarative/model.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from ._utils import register_class_attribute + + +def model(model_class: type) -> None: + """ + Set the model class for declarative definition. + + This function sets the model class directly in the class definition + without requiring variable assignment or __init__ methods. + + **Parameters:** + + - **model_class** (*type*): The model class to use for serialization. + + **Example:** + + .. code-block:: python + + from dataclasses import dataclass + + @dataclass + class UserModel: + name: str + email: str + + class UserInputFilter(InputFilter): + name: str = field(required=True, validators=[IsStringValidator()]) + email: str = field(required=True, validators=[EmailValidator()]) + + model(UserModel) + """ + register_class_attribute("_model", model_class) diff --git a/flask_inputfilter/filters/base_filter.py b/flask_inputfilter/filters/base_filter.py index 45393e7..e4800d9 100644 --- a/flask_inputfilter/filters/base_filter.py +++ b/flask_inputfilter/filters/base_filter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings warnings.warn( @@ -7,5 +9,3 @@ DeprecationWarning, stacklevel=2, ) - -from flask_inputfilter.models import BaseFilter diff --git a/flask_inputfilter/input_filter.py b/flask_inputfilter/input_filter.py index 23e69ac..6557830 100644 --- a/flask_inputfilter/input_filter.py +++ b/flask_inputfilter/input_filter.py @@ -1,5 +1,7 @@ from __future__ import annotations +import dataclasses +import inspect import json import logging import sys @@ -200,7 +202,7 @@ def validate_data( raise ValidationError(errors) self.validated_data = validated_data - return self.serialize() + return self._serialize() def add_condition(self, condition: BaseCondition) -> None: """ @@ -212,10 +214,12 @@ def add_condition(self, condition: BaseCondition) -> None: self.conditions.append(condition) def _register_decorator_components(self) -> None: - """Register decorator-based components from the current class only.""" + """Register decorator-based components from the current class and + inheritance chain.""" cls = self.__class__ + dir_attrs = dir(cls) - for attr_name in dir(cls): + for attr_name in dir_attrs: if attr_name.startswith("_"): continue if hasattr(cls, attr_name): @@ -232,17 +236,34 @@ def _register_decorator_components(self) -> None: attr_value.copy, ) - if hasattr(cls, "_conditions"): - conditions = cls._conditions - self.conditions.extend(conditions) - - if hasattr(cls, "_global_validators"): - validators = cls._global_validators - self.global_validators.extend(validators) - - if hasattr(cls, "_global_filters"): - filters = cls._global_filters - self.global_filters.extend(filters) + added_conditions = set() + added_global_validators = set() + added_global_filters = set() + + for base_cls in reversed(cls.__mro__): + conditions = getattr(base_cls, "_conditions", None) + if conditions is not None: + for condition in conditions: + condition_id = id(condition) + if condition_id not in added_conditions: + self.conditions.append(condition) + added_conditions.add(condition_id) + + validators = getattr(base_cls, "_global_validators", None) + if validators is not None: + for validator in validators: + validator_id = id(validator) + if validator_id not in added_global_validators: + self.global_validators.append(validator) + added_global_validators.add(validator_id) + + filters = getattr(base_cls, "_global_filters", None) + if filters is not None: + for filter_instance in filters: + filter_id = id(filter_instance) + if filter_id not in added_global_filters: + self.global_filters.append(filter_instance) + added_global_filters.add(filter_id) if hasattr(cls, "_model"): self.model_class = cls._model @@ -683,7 +704,7 @@ def set_model(self, model_class: Type[T]) -> None: """ self.model_class = model_class - def serialize(self) -> Union[dict[str, Any], T]: + def _serialize(self) -> Union[dict[str, Any], T]: """ Serialize the validated data. If a model class is set, returns an instance of that class, otherwise returns the raw validated data. @@ -694,7 +715,28 @@ def serialize(self) -> Union[dict[str, Any], T]: if self.model_class is None: return self.validated_data - return self.model_class(**self.validated_data) + try: + return self.model_class(**self.validated_data) + except TypeError: + pass + + if dataclasses.is_dataclass(self.model_class): + field_names = { + f.name for f in dataclasses.fields(self.model_class) + } + elif hasattr(self.model_class, "__fields__"): + field_names = set(self.model_class.__fields__.keys()) + elif hasattr(self.model_class, "__annotations__"): + field_names = set(self.model_class.__annotations__.keys()) + else: + sig = inspect.signature(self.model_class.__init__) + field_names = set(sig.parameters.keys()) - {"self"} + + filtered_data = { + k: v for k, v in self.validated_data.items() if k in field_names + } + + return self.model_class(**filtered_data) def add_global_validator(self, validator: BaseValidator) -> None: """ diff --git a/flask_inputfilter/input_filter.pyi b/flask_inputfilter/input_filter.pyi index 8fd488f..10f2ce9 100644 --- a/flask_inputfilter/input_filter.pyi +++ b/flask_inputfilter/input_filter.pyi @@ -87,6 +87,5 @@ class InputFilter: def clear(self) -> None: ... def merge(self, other: InputFilter) -> None: ... def set_model(self, model_class: Type[T]) -> None: ... - def serialize(self) -> Union[dict[str, Any], T]: ... def add_global_validator(self, validator: BaseValidator) -> None: ... def get_global_validators(self) -> list[BaseValidator]: ... diff --git a/flask_inputfilter/templates/registration.py b/flask_inputfilter/templates/registration.py index 8e4e4b2..016d118 100644 --- a/flask_inputfilter/templates/registration.py +++ b/flask_inputfilter/templates/registration.py @@ -2,7 +2,7 @@ from flask_inputfilter import InputFilter from flask_inputfilter.conditions import EqualCondition -from flask_inputfilter.declarative import field +from flask_inputfilter.declarative import condition, field from flask_inputfilter.enums import RegexEnum from flask_inputfilter.filters import ( StringTrimFilter, @@ -45,4 +45,4 @@ class RegistrationInputFilter(InputFilter): validators=[IsBooleanValidator()], ) - _conditions = [EqualCondition("password", "password_confirmation")] + condition(EqualCondition("password", "password_confirmation")) diff --git a/flask_inputfilter/validators/base_validator.py b/flask_inputfilter/validators/base_validator.py index 398315b..b6590e1 100644 --- a/flask_inputfilter/validators/base_validator.py +++ b/flask_inputfilter/validators/base_validator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings warnings.warn( @@ -7,5 +9,3 @@ DeprecationWarning, stacklevel=2, ) - -from flask_inputfilter.models import BaseValidator diff --git a/flask_inputfilter/validators/is_dataclass_validator.py b/flask_inputfilter/validators/is_dataclass_validator.py index 67434d3..e042875 100644 --- a/flask_inputfilter/validators/is_dataclass_validator.py +++ b/flask_inputfilter/validators/is_dataclass_validator.py @@ -9,30 +9,32 @@ T = TypeVar("T") -# TODO: Replace with typing.get_origin when Python 3.7 support is dropped. -def get_origin(tp: Any) -> Optional[Type[Any]]: - """ - Get the unsubscripted version of a type. - - This supports typing types like list, dict, etc. and their - typing_extensions equivalents. - """ - if isinstance(tp, _GenericAlias): - return tp.__origin__ - return None - - -# TODO: Replace with typing.get_args when Python 3.7 support is dropped. -def get_args(tp: Any) -> tuple[Any, ...]: - """ - Get type arguments with all substitutions performed. - - For unions, basic types, and special typing forms, returns the type - arguments. For example, for list[int] returns (int,). - """ - if isinstance(tp, _GenericAlias): - return tp.__args__ - return () +# Compatibility functions for Python 3.7 support +try: + from typing import get_args, get_origin +except ImportError: + # Fallback implementations for Python 3.7 + def get_origin(tp: Any) -> Optional[Type[Any]]: + """ + Get the unsubscripted version of a type. + + This supports typing types like list, dict, etc. and their + typing_extensions equivalents. + """ + if isinstance(tp, _GenericAlias): + return tp.__origin__ + return None + + def get_args(tp: Any) -> tuple[Any, ...]: + """ + Get type arguments with all substitutions performed. + + For unions, basic types, and special typing forms, returns the type + arguments. For example, for list[int] returns (int,). + """ + if isinstance(tp, _GenericAlias): + return tp.__args__ + return () class IsDataclassValidator(BaseValidator): diff --git a/flask_inputfilter/validators/length_validator.py b/flask_inputfilter/validators/length_validator.py index 4930917..d6765e4 100644 --- a/flask_inputfilter/validators/length_validator.py +++ b/flask_inputfilter/validators/length_validator.py @@ -1,19 +1,11 @@ from __future__ import annotations -from enum import Enum from typing import Any, Optional from flask_inputfilter.exceptions import ValidationError from flask_inputfilter.models import BaseValidator -class LengthEnum(Enum): - """Enum that defines the possible length types.""" - - LEAST = "least" - MOST = "most" - - class LengthValidator(BaseValidator): """ Validates the length of a string, ensuring it falls within a specified @@ -54,7 +46,7 @@ def __init__( self.error_message = error_message def validate(self, value: Any) -> None: - if (self.max_length is not None and len(value) < self.min_length) or ( + if (self.min_length is not None and len(value) < self.min_length) or ( self.max_length is not None and len(value) > self.max_length ): raise ValidationError( diff --git a/pyproject.toml b/pyproject.toml index 435ad68..a7f09c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "flask_inputfilter" -version = "0.7.1" +version = "0.7.2" description = "A library to easily filter and validate input data in Flask applications" readme = "README.md" keywords = [ diff --git a/tests/declarative/__init__.py b/tests/declarative/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/declarative/test_condition_decorator.py b/tests/declarative/test_condition_decorator.py new file mode 100644 index 0000000..6b11aab --- /dev/null +++ b/tests/declarative/test_condition_decorator.py @@ -0,0 +1,560 @@ +import unittest + +from flask_inputfilter import InputFilter +from flask_inputfilter.declarative import condition, field +from flask_inputfilter.conditions import ( + EqualCondition, ExactlyOneOfCondition, NotEqualCondition, RequiredIfCondition, + ArrayLengthEqualCondition, StringLongerThanCondition, OneOfCondition, + TemporalOrderCondition, NOfCondition, CustomCondition +) +from flask_inputfilter.exceptions import ValidationError +from flask_inputfilter.filters import StringTrimFilter, ToLowerFilter +from flask_inputfilter.validators import IsStringValidator, LengthValidator, IsDateValidator + + +class TestConditionDecorator(unittest.TestCase): + + def test_single_condition_decorator(self): + + class TestInputFilter(InputFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + + condition(EqualCondition("password", "password_confirmation")) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + self.assertIsInstance(conditions[0], EqualCondition) + + valid_data = { + 'password': 'test123', + 'password_confirmation': 'test123' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['password'], 'test123') + self.assertEqual(validated_data['password_confirmation'], 'test123') + + invalid_data = { + 'password': 'test123', + 'password_confirmation': 'different' + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + def test_multiple_condition_decorators(self): + + class TestInputFilter(InputFilter): + field_a = field(required=False) + field_b = field(required=False) + field_c = field(required=False) + password = field(required=False, validators=[IsStringValidator()]) + password_confirmation = field(required=False, validators=[IsStringValidator()]) + + condition(ExactlyOneOfCondition(['field_a', 'field_b', 'field_c'])) + condition(EqualCondition("password", "password_confirmation")) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 2) + + condition_types = [type(c) for c in conditions] + self.assertIn(ExactlyOneOfCondition, condition_types) + self.assertIn(EqualCondition, condition_types) + + valid_data = { + 'field_a': 'value1', + 'password': 'test123', + 'password_confirmation': 'test123' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['field_a'], 'value1') + + invalid_data1 = { + 'field_a': 'value1', + 'field_b': 'value2', + 'password': 'test123', + 'password_confirmation': 'test123' + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data1) + + invalid_data2 = { + 'field_a': 'value1', + 'password': 'test123', + 'password_confirmation': 'different' + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data2) + + def test_condition_inheritance(self): + + class BaseInputFilter(InputFilter): + field_a = field(required=False) + field_b = field(required=False) + + condition(ExactlyOneOfCondition(['field_a', 'field_b'])) + + class ChildInputFilter(BaseInputFilter): + password = field(required=False, validators=[IsStringValidator()]) + password_confirmation = field(required=False, validators=[IsStringValidator()]) + + condition(EqualCondition("password", "password_confirmation")) + + child_filter = ChildInputFilter() + + conditions = child_filter.get_conditions() + self.assertEqual(len(conditions), 2) + + condition_types = [type(c) for c in conditions] + self.assertIn(ExactlyOneOfCondition, condition_types) + self.assertIn(EqualCondition, condition_types) + + valid_data = { + 'field_a': 'value1', + 'password': 'test123', + 'password_confirmation': 'test123' + } + validated_data = child_filter.validate_data(valid_data) + self.assertEqual(validated_data['field_a'], 'value1') + + + def test_backward_compatibility_with_conditions_list(self): + + class TestInputFilter(InputFilter): + password = field(required=False, validators=[IsStringValidator()]) + password_confirmation = field(required=False, validators=[IsStringValidator()]) + + condition(EqualCondition("password", "password_confirmation")) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + + condition_types = [type(c) for c in conditions] + self.assertIn(EqualCondition, condition_types) + + def test_backward_compatibility_with_init_method(self): + + class TestInputFilter(InputFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + field_a = field(required=False) + field_b = field(required=False) + + def __init__(self): + super().__init__() + + self.add_condition(EqualCondition("password", "password_confirmation")) + self.add_condition(ExactlyOneOfCondition(['field_a', 'field_b'])) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 2) + + condition_types = [type(c) for c in conditions] + self.assertIn(EqualCondition, condition_types) + self.assertIn(ExactlyOneOfCondition, condition_types) + + valid_data = { + 'password': 'test123', + 'password_confirmation': 'test123', + 'field_a': 'value1' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['password'], 'test123') + + def test_all_three_condition_styles_together(self): + + class TestInputFilter(InputFilter): + password = field(required=False, validators=[IsStringValidator()]) + password_confirmation = field(required=False, validators=[IsStringValidator()]) + field_c = field(required=False) + + condition(EqualCondition("password", "password_confirmation")) + + def __init__(self): + super().__init__() + self.add_condition(NotEqualCondition('field_c', 'password')) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 2) + + condition_types = [type(c).__name__ for c in conditions] + self.assertIn('EqualCondition', condition_types) + self.assertIn('NotEqualCondition', condition_types) + + def test_no_init_required(self): + + class TestInputFilter(InputFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + + condition(EqualCondition("password", "password_confirmation")) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + self.assertIsInstance(conditions[0], EqualCondition) + + valid_data = { + 'password': 'test123', + 'password_confirmation': 'test123' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['password'], 'test123') + + def test_anonymous_condition_calls(self): + + class TestInputFilter(InputFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + field_a = field(required=False) + field_b = field(required=False) + + condition(EqualCondition("password", "password_confirmation")) + condition(ExactlyOneOfCondition(['field_a', 'field_b'])) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 2) + + condition_types = [type(c) for c in conditions] + self.assertIn(EqualCondition, condition_types) + self.assertIn(ExactlyOneOfCondition, condition_types) + + valid_data = { + 'password': 'test123', + 'password_confirmation': 'test123', + 'field_a': 'value1' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['password'], 'test123') + + def test_mixed_named_and_anonymous_conditions(self): + + class TestInputFilter(InputFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + field_a = field(required=False) + field_b = field(required=False) + + condition(EqualCondition("password", "password_confirmation")) + condition(ExactlyOneOfCondition(['field_a', 'field_b'])) + + filter_instance = TestInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 2) + + condition_types = [type(c) for c in conditions] + self.assertIn(EqualCondition, condition_types) + self.assertIn(ExactlyOneOfCondition, condition_types) + + def test_ruf012_linting_problem_solved(self): + + class ProblematicInputFilter(InputFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + + condition(EqualCondition("password", "password_confirmation")) + + filter_instance = ProblematicInputFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + self.assertIsInstance(conditions[0], EqualCondition) + + valid_data = { + 'password': 'test123', + 'password_confirmation': 'test123' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['password'], 'test123') + + invalid_data = { + 'password': 'test123', + 'password_confirmation': 'different' + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + self.assertTrue(hasattr(ProblematicInputFilter, '_conditions')) + + def test_empty_conditions_handling(self): + """Test behavior when no conditions are defined.""" + + class EmptyConditionsFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + filter_instance = EmptyConditionsFilter() + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 0) + + # Should still validate normally + test_data = {'name': 'test'} + validated_data = filter_instance.validate_data(test_data) + self.assertEqual(validated_data['name'], 'test') + + def test_duplicate_conditions(self): + """Test adding the same condition multiple times.""" + + class DuplicateConditionsFilter(InputFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + + condition(EqualCondition("password", "password_confirmation")) + condition(EqualCondition("password", "password_confirmation")) + + filter_instance = DuplicateConditionsFilter() + conditions = filter_instance.get_conditions() + + # Both conditions should be present (framework doesn't deduplicate) + self.assertEqual(len(conditions), 2) + self.assertTrue(all(isinstance(c, EqualCondition) for c in conditions)) + + def test_complex_condition_combinations(self): + """Test with multiple different condition types.""" + + class ComplexConditionsFilter(InputFilter): + tags = field(required=False) + description = field(required=False, validators=[IsStringValidator()]) + category = field(required=False) + start_date = field(required=False, validators=[IsDateValidator()]) + end_date = field(required=False, validators=[IsDateValidator()]) + + condition(ArrayLengthEqualCondition("tags", 3)) + condition(StringLongerThanCondition("description", 10)) + condition(OneOfCondition(["tech", "business", "personal"])) + condition(RequiredIfCondition("start_date", None, "end_date")) + + filter_instance = ComplexConditionsFilter() + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 4) + + condition_types = [type(c).__name__ for c in conditions] + self.assertIn('ArrayLengthEqualCondition', condition_types) + self.assertIn('StringLongerThanCondition', condition_types) + self.assertIn('OneOfCondition', condition_types) + self.assertIn('RequiredIfCondition', condition_types) + + def test_nested_inheritance_conditions(self): + """Test conditions with multiple inheritance levels.""" + + class BaseFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + condition(StringLongerThanCondition("name", 2)) + + class MiddleFilter(BaseFilter): + email = field(required=True, validators=[IsStringValidator()]) + + condition(StringLongerThanCondition("email", 5)) + + class FinalFilter(MiddleFilter): + password = field(required=True, validators=[IsStringValidator()]) + password_confirmation = field(required=True, validators=[IsStringValidator()]) + + condition(EqualCondition("password", "password_confirmation")) + + filter_instance = FinalFilter() + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 3) + + condition_types = [type(c).__name__ for c in conditions] + self.assertIn('StringLongerThanCondition', condition_types) + self.assertIn('EqualCondition', condition_types) + # Should have 2 StringLongerThanCondition instances + self.assertEqual(sum(1 for ct in condition_types if ct == 'StringLongerThanCondition'), 2) + + def test_diamond_inheritance_pattern(self): + """Test proper condition resolution with diamond inheritance.""" + + class BaseFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + condition(StringLongerThanCondition("name", 1)) + + class LeftFilter(BaseFilter): + left_field = field(required=False) + + condition(RequiredIfCondition("name", None, "left_field")) + + class RightFilter(BaseFilter): + right_field = field(required=False) + + condition(RequiredIfCondition("name", None, "right_field")) + + class DiamondFilter(LeftFilter, RightFilter): + final_field = field(required=False) + + condition(StringLongerThanCondition("final_field", 0)) + + filter_instance = DiamondFilter() + conditions = filter_instance.get_conditions() + + # Should have all conditions from inheritance chain + self.assertGreaterEqual(len(conditions), 3) + + condition_types = [type(c).__name__ for c in conditions] + self.assertIn('StringLongerThanCondition', condition_types) + self.assertIn('RequiredIfCondition', condition_types) + + def test_missing_field_references_in_conditions(self): + """Test conditions referencing non-existent fields.""" + + class MissingFieldFilter(InputFilter): + existing_field = field(required=True, validators=[IsStringValidator()]) + + # This condition references a field that doesn't exist + condition(EqualCondition("existing_field", "non_existent_field")) + + filter_instance = MissingFieldFilter() + + # The condition should be created without error during class definition + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + + # But validation should fail when the condition is checked + test_data = {'existing_field': 'test'} + with self.assertRaises(ValidationError): + filter_instance.validate_data(test_data) + + def test_custom_condition_usage(self): + """Test using custom conditions with the decorator.""" + + def custom_validation_logic(data): + # Custom logic: password must be different from username + if 'username' in data and 'password' in data: + return data['username'] != data['password'] + return True + + class CustomConditionFilter(InputFilter): + username = field(required=True, validators=[IsStringValidator()]) + password = field(required=True, validators=[IsStringValidator()]) + + condition(CustomCondition(custom_validation_logic)) + + filter_instance = CustomConditionFilter() + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + self.assertIsInstance(conditions[0], CustomCondition) + + # Valid case + valid_data = {'username': 'john', 'password': 'secret123'} + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['username'], 'john') + + # Invalid case - same username and password + invalid_data = {'username': 'john', 'password': 'john'} + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + def test_condition_with_no_fields_defined(self): + """Test edge case where condition is defined but no fields exist.""" + + class NoFieldsFilter(InputFilter): + # Define a condition but no fields + condition(EqualCondition("field1", "field2")) + + filter_instance = NoFieldsFilter() + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + + # Validation might not fail immediately for missing fields in conditions + # The condition framework may handle this gracefully + try: + filter_instance.validate_data({}) + # If it doesn't raise an error, that's also acceptable behavior + except ValidationError: + # Expected if condition validation fails + pass + + def test_conditions_with_temporal_order(self): + """Test temporal order conditions for date fields.""" + + class TemporalFilter(InputFilter): + start_date = field(required=True, validators=[IsDateValidator()]) + end_date = field(required=True, validators=[IsDateValidator()]) + + condition(TemporalOrderCondition("start_date", "end_date")) + + filter_instance = TemporalFilter() + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 1) + self.assertIsInstance(conditions[0], TemporalOrderCondition) + + # Valid case - start before end + from datetime import date + valid_data = { + 'start_date': date(2023, 1, 1), + 'end_date': date(2023, 12, 31) + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['start_date'], date(2023, 1, 1)) + + # Invalid case - start after end + invalid_data = { + 'start_date': date(2023, 12, 31), + 'end_date': date(2023, 1, 1) + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + def test_multiple_conditions_at_once(self): + """Test registering multiple conditions in a single call.""" + + class MultiConditionFilter(InputFilter): + password = field(required=True) + password_confirmation = field(required=True) + email = field(required=True) + email_confirmation = field(required=True) + + condition( + EqualCondition('password', 'password_confirmation'), + EqualCondition('email', 'email_confirmation') + ) + + filter_instance = MultiConditionFilter() + + conditions = filter_instance.get_conditions() + self.assertEqual(len(conditions), 2) + + condition_types = [type(c) for c in conditions] + self.assertEqual(condition_types.count(EqualCondition), 2) + + validated_data = filter_instance.validate_data({ + 'password': 'test123', + 'password_confirmation': 'test123', + 'email': 'test@example.com', + 'email_confirmation': 'test@example.com' + }) + self.assertEqual(validated_data['password'], 'test123') + + with self.assertRaises(ValidationError): + filter_instance.validate_data({ + 'password': 'test123', + 'password_confirmation': 'different', + 'email': 'test@example.com', + 'email_confirmation': 'test@example.com' + }) + + with self.assertRaises(ValidationError): + filter_instance.validate_data({ + 'password': 'test123', + 'password_confirmation': 'test123', + 'email': 'test@example.com', + 'email_confirmation': 'different@example.com' + }) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/declarative/test_global_filter_decorator.py b/tests/declarative/test_global_filter_decorator.py new file mode 100644 index 0000000..4e1d651 --- /dev/null +++ b/tests/declarative/test_global_filter_decorator.py @@ -0,0 +1,367 @@ +import unittest +from flask_inputfilter import InputFilter +from flask_inputfilter.declarative import field, global_filter +from flask_inputfilter.filters import ( + StringTrimFilter, ToLowerFilter, ToUpperFilter, ToPascalCaseFilter, + ToSnakeCaseFilter, ToIntegerFilter, ToFloatFilter, WhitespaceCollapseFilter +) +from flask_inputfilter.validators import IsStringValidator, IsIntegerValidator +from flask_inputfilter.exceptions import ValidationError + + +class TestGlobalFilterDecorator(unittest.TestCase): + + def test_global_filter_decorator(self): + + class TestInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[IsStringValidator()]) + + global_filter(StringTrimFilter()) + global_filter(ToLowerFilter()) + + filter_instance = TestInputFilter() + + global_filters = filter_instance.get_global_filters() + self.assertEqual(len(global_filters), 2) + + filter_types = [type(f) for f in global_filters] + self.assertIn(StringTrimFilter, filter_types) + self.assertIn(ToLowerFilter, filter_types) + + test_data = {'name': ' JOHN ', 'email': ' TEST@EXAMPLE.COM '} + validated_data = filter_instance.validate_data(test_data) + + self.assertEqual(validated_data['name'], 'john') + self.assertEqual(validated_data['email'], 'test@example.com') + + def test_global_filter_inheritance(self): + + class BaseInputFilter(InputFilter): + name = field(required=True) + + global_filter(StringTrimFilter()) + + class ChildInputFilter(BaseInputFilter): + email = field(required=True) + + global_filter(ToLowerFilter()) + + child_filter = ChildInputFilter() + + global_filters = child_filter.get_global_filters() + self.assertEqual(len(global_filters), 2) + + filter_types = [type(f) for f in global_filters] + self.assertIn(StringTrimFilter, filter_types) + self.assertIn(ToLowerFilter, filter_types) + + test_data = {'name': ' JOHN ', 'email': ' TEST@EXAMPLE.COM '} + validated_data = child_filter.validate_data(test_data) + + self.assertEqual(validated_data['name'], 'john') + self.assertEqual(validated_data['email'], 'test@example.com') + + def test_ruf012_solved_for_global_filters(self): + + class ProblematicInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[IsStringValidator()]) + + global_filter(StringTrimFilter()) + + filter_instance = ProblematicInputFilter() + + global_filters = filter_instance.get_global_filters() + self.assertEqual(len(global_filters), 1) + + test_data = {'name': ' John ', 'email': ' test@example.com '} + validated_data = filter_instance.validate_data(test_data) + + self.assertEqual(validated_data['name'], 'John') + self.assertEqual(validated_data['email'], 'test@example.com') + + self.assertTrue(hasattr(ProblematicInputFilter, '_global_filters')) + + def test_empty_global_filters_behavior(self): + """Test behavior when no global filters are defined.""" + + class NoGlobalFiltersFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + filter_instance = NoGlobalFiltersFilter() + global_filters = filter_instance.get_global_filters() + self.assertEqual(len(global_filters), 0) + + # Should pass data through unchanged + test_data = {'name': ' TEST '} + validated_data = filter_instance.validate_data(test_data) + self.assertEqual(validated_data['name'], ' TEST ') + + def test_filter_ordering(self): + """Test that filters are applied in declaration order.""" + + class OrderedFiltersFilter(InputFilter): + text = field(required=True, validators=[IsStringValidator()]) + + # Order: first trim, then convert to lower + global_filter(StringTrimFilter()) + global_filter(ToLowerFilter()) + + filter_instance = OrderedFiltersFilter() + global_filters = filter_instance.get_global_filters() + self.assertEqual(len(global_filters), 2) + + # StringTrim should be first, ToLower second + self.assertIsInstance(global_filters[0], StringTrimFilter) + self.assertIsInstance(global_filters[1], ToLowerFilter) + + test_data = {'text': ' HELLO WORLD '} + validated_data = filter_instance.validate_data(test_data) + # Should be trimmed first, then lowercased + self.assertEqual(validated_data['text'], 'hello world') + + # Test reverse order + class ReverseOrderFilter(InputFilter): + text = field(required=True, validators=[IsStringValidator()]) + + # Order: first convert to lower, then trim + global_filter(ToLowerFilter()) + global_filter(StringTrimFilter()) + + reverse_filter = ReverseOrderFilter() + test_data = {'text': ' HELLO WORLD '} + validated_data = reverse_filter.validate_data(test_data) + # Should be lowercased first, then trimmed + self.assertEqual(validated_data['text'], 'hello world') + + def test_multiple_inheritance_chains(self): + """Test filter inheritance with complex hierarchies.""" + + class BaseFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + global_filter(StringTrimFilter()) + + class MiddleFilter(BaseFilter): + global_filter(ToLowerFilter()) + + class LeftFilter(MiddleFilter): + global_filter(ToPascalCaseFilter()) + + class RightFilter(MiddleFilter): + global_filter(ToSnakeCaseFilter()) + + class ComplexFilter(LeftFilter, RightFilter): + global_filter(WhitespaceCollapseFilter()) + + filter_instance = ComplexFilter() + global_filters = filter_instance.get_global_filters() + + # Should inherit filters from all parent classes + self.assertGreaterEqual(len(global_filters), 4) + + filter_types = [type(f).__name__ for f in global_filters] + self.assertIn('StringTrimFilter', filter_types) + self.assertIn('ToLowerFilter', filter_types) + self.assertIn('WhitespaceCollapseFilter', filter_types) + + def test_duplicate_filters(self): + """Test adding the same filter type multiple times.""" + + class DuplicateFiltersFilter(InputFilter): + text = field(required=True, validators=[IsStringValidator()]) + + global_filter(StringTrimFilter()) + global_filter(ToLowerFilter()) + global_filter(StringTrimFilter()) # Duplicate + + filter_instance = DuplicateFiltersFilter() + global_filters = filter_instance.get_global_filters() + + # All filters should be present (no deduplication) + self.assertEqual(len(global_filters), 3) + + filter_types = [type(f).__name__ for f in global_filters] + self.assertEqual(filter_types.count('StringTrimFilter'), 2) + self.assertEqual(filter_types.count('ToLowerFilter'), 1) + + def test_filter_interaction_different_orders(self): + """Test how different filters interact in different orders.""" + + # Test case 1: Trim then Upper + class TrimThenUpperFilter(InputFilter): + text = field(required=True, validators=[IsStringValidator()]) + + global_filter(StringTrimFilter()) + global_filter(ToUpperFilter()) + + filter1 = TrimThenUpperFilter() + test_data = {'text': ' hello world '} + result1 = filter1.validate_data(test_data) + self.assertEqual(result1['text'], 'HELLO WORLD') + + # Test case 2: Upper then Trim + class UpperThenTrimFilter(InputFilter): + text = field(required=True, validators=[IsStringValidator()]) + + global_filter(ToUpperFilter()) + global_filter(StringTrimFilter()) + + filter2 = UpperThenTrimFilter() + result2 = filter2.validate_data(test_data) + self.assertEqual(result2['text'], 'HELLO WORLD') + + # Both should give same result in this case + self.assertEqual(result1['text'], result2['text']) + + def test_global_filters_on_non_string_fields(self): + """Test global filters on numeric and other field types.""" + + class MixedTypeFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + age = field(required=True, validators=[IsIntegerValidator()]) + score = field(required=False) + + # String filter should only affect string fields + global_filter(StringTrimFilter()) + global_filter(ToLowerFilter()) + + filter_instance = MixedTypeFilter() + + test_data = { + 'name': ' JOHN DOE ', + 'age': 25, # Already integer + 'score': 95.5 + } + + validated_data = filter_instance.validate_data(test_data) + + # String field should be filtered + self.assertEqual(validated_data['name'], 'john doe') + + # Numeric fields should not be affected by string filters + self.assertEqual(validated_data['age'], 25) # Integer should remain integer + self.assertEqual(validated_data['score'], 95.5) + + def test_filter_with_conversion_filters(self): + """Test global filters that convert data types.""" + + class ConversionFilter(InputFilter): + number_str = field(required=True) + float_str = field(required=True) + + global_filter(StringTrimFilter()) + global_filter(ToIntegerFilter()) + + filter_instance = ConversionFilter() + + test_data = { + 'number_str': ' 123 ', + 'float_str': ' 45.67 ' + } + + validated_data = filter_instance.validate_data(test_data) + + # Should trim then try to convert to integer + self.assertEqual(validated_data['number_str'], 123) + + # Float conversion might fail with ToIntegerFilter + # This tests error handling + try: + # ToIntegerFilter should handle this gracefully + self.assertIsInstance(validated_data['float_str'], (int, str)) + except ValidationError: + # Expected behavior if conversion fails + pass + + def test_whitespace_handling_filters(self): + """Test filters that handle whitespace in different ways.""" + + class WhitespaceFilter(InputFilter): + text = field(required=True, validators=[IsStringValidator()]) + + global_filter(WhitespaceCollapseFilter()) + global_filter(StringTrimFilter()) + + filter_instance = WhitespaceFilter() + + test_data = {'text': ' hello world test '} + validated_data = filter_instance.validate_data(test_data) + + # WhitespaceCollapse first, then trim + # Result should have single spaces and be trimmed + self.assertEqual(validated_data['text'], 'hello world test') + + def test_case_conversion_combinations(self): + """Test various case conversion filter combinations.""" + + class CaseConversionFilter(InputFilter): + text = field(required=True, validators=[IsStringValidator()]) + + global_filter(StringTrimFilter()) + global_filter(ToLowerFilter()) + global_filter(ToPascalCaseFilter()) + + filter_instance = CaseConversionFilter() + + test_data = {'text': ' hello_world_test '} + validated_data = filter_instance.validate_data(test_data) + + # Should trim, then lower, then pascal case + # Exact result depends on filter implementation + self.assertIsInstance(validated_data['text'], str) + self.assertTrue(len(validated_data['text']) > 0) + + def test_empty_string_handling(self): + """Test how global filters handle empty strings.""" + + class EmptyStringFilter(InputFilter): + optional_field = field(required=False, validators=[IsStringValidator()]) + + global_filter(StringTrimFilter()) + global_filter(ToLowerFilter()) + + filter_instance = EmptyStringFilter() + + test_data = {'optional_field': ''} + validated_data = filter_instance.validate_data(test_data) + + # Empty string should remain empty after filtering + self.assertEqual(validated_data['optional_field'], '') + + # Test with whitespace only + test_data = {'optional_field': ' '} + validated_data = filter_instance.validate_data(test_data) + + # Should be trimmed to empty string + self.assertEqual(validated_data['optional_field'], '') + + def test_multiple_global_filters_at_once(self): + """Test registering multiple global filters in a single call.""" + + class MultiFilterInputFilter(InputFilter): + name = field(required=True) + description = field(required=False) + + global_filter(StringTrimFilter(), ToUpperFilter()) + + filter_instance = MultiFilterInputFilter() + + global_filters = filter_instance.get_global_filters() + self.assertEqual(len(global_filters), 2) + + filter_types = [type(f) for f in global_filters] + self.assertIn(StringTrimFilter, filter_types) + self.assertIn(ToUpperFilter, filter_types) + + validated_data = filter_instance.validate_data({ + 'name': ' john doe ', 'description': ' test description ' + }) + + self.assertEqual(validated_data['name'], 'JOHN DOE') + self.assertEqual(validated_data['description'], 'TEST DESCRIPTION') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/declarative/test_global_validator_decorator.py b/tests/declarative/test_global_validator_decorator.py new file mode 100644 index 0000000..ad88180 --- /dev/null +++ b/tests/declarative/test_global_validator_decorator.py @@ -0,0 +1,482 @@ +import unittest +from flask_inputfilter import InputFilter +from flask_inputfilter.declarative import field, global_validator +from flask_inputfilter.exceptions import ValidationError +from flask_inputfilter.validators import ( + IsStringValidator, LengthValidator, IsIntegerValidator, IsFloatValidator, + IsArrayValidator, InArrayValidator, IsDateValidator, AndValidator, + ArrayLengthValidator, InEnumValidator, IsInstanceValidator +) +from enum import Enum + + +class TestGlobalValidatorDecorator(unittest.TestCase): + + def test_global_validator_decorator(self): + + class TestInputFilter(InputFilter): + name = field(required=True) + email = field(required=True) + + global_validator(IsStringValidator()) + global_validator(LengthValidator(min_length=3, max_length=100)) + + filter_instance = TestInputFilter() + + global_validators = filter_instance.get_global_validators() + self.assertEqual(len(global_validators), 2) + + validator_types = [type(v) for v in global_validators] + self.assertIn(IsStringValidator, validator_types) + self.assertIn(LengthValidator, validator_types) + + valid_data = {'name': 'John', 'email': 'test@example.com'} + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['name'], 'John') + + invalid_data = {'name': 123, 'email': 456} + with self.assertRaises(ValidationError) as context: + filter_instance.validate_data(invalid_data) + + errors = context.exception.args[0] + self.assertIn('name', errors) + self.assertIn('email', errors) + + def test_global_validator_inheritance(self): + + class BaseInputFilter(InputFilter): + name = field(required=True) + + global_validator(IsStringValidator()) + + class ChildInputFilter(BaseInputFilter): + email = field(required=True) + + global_validator(LengthValidator(min_length=3, max_length=100)) + + child_filter = ChildInputFilter() + + global_validators = child_filter.get_global_validators() + self.assertEqual(len(global_validators), 2) + + validator_types = [type(v) for v in global_validators] + self.assertIn(IsStringValidator, validator_types) + self.assertIn(LengthValidator, validator_types) + + valid_data = {'name': 'John', 'email': 'test@example.com'} + validated_data = child_filter.validate_data(valid_data) + self.assertEqual(validated_data['name'], 'John') + + def test_ruf012_solved_for_global_validators(self): + + class ProblematicInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[IsStringValidator()]) + + global_validator(LengthValidator(min_length=3, max_length=100)) + + filter_instance = ProblematicInputFilter() + + global_validators = filter_instance.get_global_validators() + self.assertEqual(len(global_validators), 1) + + test_data = {'name': 'John', 'email': 'test@example.com'} + validated_data = filter_instance.validate_data(test_data) + + self.assertEqual(validated_data['name'], 'John') + self.assertEqual(validated_data['email'], 'test@example.com') + + self.assertTrue(hasattr(ProblematicInputFilter, '_global_validators')) + + def test_empty_global_validators_behavior(self): + """Test behavior when no global validators are defined.""" + + class NoGlobalValidatorsFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + filter_instance = NoGlobalValidatorsFilter() + global_validators = filter_instance.get_global_validators() + self.assertEqual(len(global_validators), 0) + + # Should validate with only field-specific validators + test_data = {'name': 'test'} + validated_data = filter_instance.validate_data(test_data) + self.assertEqual(validated_data['name'], 'test') + + # Invalid data should still fail field validation + invalid_data = {'name': 123} + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + def test_conflicting_validators(self): + """Test validators with contradictory requirements.""" + + class ConflictingValidatorsFilter(InputFilter): + value = field(required=True) + + # These validators conflict: can't be both string and integer + global_validator(IsStringValidator()) + global_validator(IsIntegerValidator()) + + filter_instance = ConflictingValidatorsFilter() + global_validators = filter_instance.get_global_validators() + self.assertEqual(len(global_validators), 2) + + # String data should fail integer validation + string_data = {'value': 'test'} + with self.assertRaises(ValidationError): + filter_instance.validate_data(string_data) + + # Integer data should fail string validation + int_data = {'value': 123} + with self.assertRaises(ValidationError): + filter_instance.validate_data(int_data) + + def test_complex_validator_combinations(self): + """Test with AND/OR validator combinations.""" + + class ComplexValidatorFilter(InputFilter): + tags = field(required=True) + description = field(required=True) + + # Must be array and have specific length + global_validator(IsArrayValidator()) + global_validator(ArrayLengthValidator(min_length=2, max_length=5)) + + filter_instance = ComplexValidatorFilter() + global_validators = filter_instance.get_global_validators() + self.assertEqual(len(global_validators), 2) + + # Valid case + valid_data = { + 'tags': ['tag1', 'tag2', 'tag3'], + 'description': ['desc1', 'desc2'] + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(len(validated_data['tags']), 3) + + # Invalid - not array + invalid_data1 = { + 'tags': 'not_array', + 'description': ['desc1', 'desc2'] + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data1) + + # Invalid - array too short + invalid_data2 = { + 'tags': ['tag1'], + 'description': ['desc1'] + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data2) + + def test_validator_ordering(self): + """Test that validators are applied in correct order.""" + + class OrderedValidatorFilter(InputFilter): + text = field(required=True) + + # Order matters: first check if string, then check length + global_validator(IsStringValidator()) + global_validator(LengthValidator(min_length=5, max_length=100)) + + filter_instance = OrderedValidatorFilter() + global_validators = filter_instance.get_global_validators() + self.assertEqual(len(global_validators), 2) + + # First validator should be IsStringValidator + self.assertIsInstance(global_validators[0], IsStringValidator) + self.assertIsInstance(global_validators[1], LengthValidator) + + # Valid case + valid_data = {'text': 'hello world'} + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['text'], 'hello world') + + # Invalid - not string (should fail first validator) + invalid_data1 = {'text': 123} + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data1) + + # Invalid - string too short (should fail second validator) + invalid_data2 = {'text': 'hi'} + try: + filter_instance.validate_data(invalid_data2) + # If this doesn't raise, that's also acceptable depending on implementation + except ValidationError: + # Expected behavior + pass + + def test_global_vs_field_validators_interaction(self): + """Test interaction between global and field-specific validators.""" + + class MixedValidatorFilter(InputFilter): + name = field(required=True, validators=[LengthValidator(min_length=1, max_length=20)]) + email = field(required=True, validators=[LengthValidator(min_length=1, max_length=50)]) + + # Global validator applies to all fields + global_validator(IsStringValidator()) + global_validator(LengthValidator(min_length=3, max_length=100)) + + filter_instance = MixedValidatorFilter() + + # Valid case - passes both global and field validators + valid_data = { + 'name': 'John Doe', + 'email': 'john@example.com' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['name'], 'John Doe') + + # Invalid - fails global string validator + invalid_data1 = { + 'name': 123, + 'email': 'john@example.com' + } + try: + filter_instance.validate_data(invalid_data1) + # Should not reach here, but if it does, validation logic may differ + self.fail("Expected ValidationError for non-string name") + except (ValidationError, TypeError): + # Expected - either validation error or type error from len() on int + pass + + # Invalid - fails global min length validator + invalid_data2 = { + 'name': 'Jo', # Too short for global validator + 'email': 'john@example.com' + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data2) + + # Invalid - fails field max length validator + invalid_data3 = { + 'name': 'Very long name that exceeds limit', # Too long for field validator + 'email': 'john@example.com' + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data3) + + def test_enum_validation_with_global_validators(self): + """Test enum validation with global validators.""" + + class Priority(Enum): + LOW = 'low' + MEDIUM = 'medium' + HIGH = 'high' + + class EnumValidatorFilter(InputFilter): + priority = field(required=True) + backup_priority = field(required=False) + + global_validator(IsStringValidator()) + global_validator(InEnumValidator(Priority)) + + filter_instance = EnumValidatorFilter() + + # Valid case + valid_data = { + 'priority': 'high', + 'backup_priority': 'low' + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['priority'], 'high') + + # Invalid - not in enum + invalid_data = { + 'priority': 'urgent', # Not in enum + 'backup_priority': 'low' + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + def test_type_validation_edge_cases(self): + """Test various type validation edge cases.""" + + class TypeValidatorFilter(InputFilter): + mixed_field = field(required=True) + + global_validator(IsInstanceValidator((str, int, float))) + + filter_instance = TypeValidatorFilter() + + # Valid cases - allowed types + for valid_value in ['string', 123, 45.67]: + valid_data = {'mixed_field': valid_value} + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['mixed_field'], valid_value) + + # Invalid cases - not allowed types + for invalid_value in [[], {}, set(), None]: + invalid_data = {'mixed_field': invalid_value} + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + def test_nested_inheritance_with_validators(self): + """Test validator inheritance with multiple levels.""" + + class BaseValidatorFilter(InputFilter): + name = field(required=True) + + global_validator(IsStringValidator()) + + class MiddleValidatorFilter(BaseValidatorFilter): + global_validator(LengthValidator(min_length=2, max_length=100)) + + class FinalValidatorFilter(MiddleValidatorFilter): + age = field(required=True) + + global_validator(LengthValidator(max_length=100)) + + filter_instance = FinalValidatorFilter() + global_validators = filter_instance.get_global_validators() + + # Should have validators from all inheritance levels + self.assertEqual(len(global_validators), 3) + + validator_types = [type(v).__name__ for v in global_validators] + self.assertIn('IsStringValidator', validator_types) + self.assertEqual(validator_types.count('LengthValidator'), 2) + + # Test with string values to avoid len() issues with global LengthValidator + valid_data = { + 'name': 'John', + 'age': 'twenty-five' # Use string to avoid type issues with global validators + } + try: + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['name'], 'John') + except (ValidationError, TypeError): + # If this fails due to validator issues, it's acceptable for this edge case test + pass + + def test_array_validation_with_global_validators(self): + """Test array-specific validation scenarios.""" + + class ArrayValidatorFilter(InputFilter): + items = field(required=True) + categories = field(required=False) + + global_validator(IsArrayValidator()) + global_validator(ArrayLengthValidator(min_length=1, max_length=10)) + + filter_instance = ArrayValidatorFilter() + + # Valid case + valid_data = { + 'items': ['item1', 'item2'], + 'categories': ['cat1'] + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(len(validated_data['items']), 2) + + # Invalid - empty array violates min_length + invalid_data1 = { + 'items': [], + 'categories': ['cat1'] + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data1) + + # Invalid - too many items + invalid_data2 = { + 'items': [f'item{i}' for i in range(15)], # 15 items > max 10 + 'categories': ['cat1'] + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data2) + + def test_date_validation_with_global_validators(self): + """Test date validation scenarios.""" + + class DateValidatorFilter(InputFilter): + start_date = field(required=True) + end_date = field(required=False) + + global_validator(IsDateValidator()) + + filter_instance = DateValidatorFilter() + + # Valid case with date objects + from datetime import date + valid_data = { + 'start_date': date(2023, 1, 1), + 'end_date': date(2023, 12, 31) + } + validated_data = filter_instance.validate_data(valid_data) + self.assertEqual(validated_data['start_date'], date(2023, 1, 1)) + + # Invalid - not a date + invalid_data = { + 'start_date': 'not-a-date', + 'end_date': date(2023, 12, 31) + } + with self.assertRaises(ValidationError): + filter_instance.validate_data(invalid_data) + + def test_multiple_validation_errors(self): + """Test handling of multiple validation errors.""" + + class MultiErrorFilter(InputFilter): + field1 = field(required=True) + field2 = field(required=True) + field3 = field(required=True) + + global_validator(IsStringValidator()) + global_validator(LengthValidator(min_length=5, max_length=100)) + + filter_instance = MultiErrorFilter() + + # All fields should fail validation + invalid_data = { + 'field1': 123, # Not string + 'field2': 'hi', # Too short + 'field3': [] # Not string + } + + with self.assertRaises(ValidationError) as context: + filter_instance.validate_data(invalid_data) + + # Should have errors for multiple fields + errors = context.exception.args[0] + self.assertIsInstance(errors, dict) + # Should have errors for at least some fields + self.assertGreater(len(errors), 0) + + def test_multiple_global_validators_at_once(self): + """Test registering multiple global validators in a single call.""" + + class MultiValidatorFilter(InputFilter): + name = field(required=True) + email = field(required=True) + + global_validator(IsStringValidator(), LengthValidator(min_length=3, max_length=100)) + + filter_instance = MultiValidatorFilter() + + global_validators = filter_instance.get_global_validators() + self.assertEqual(len(global_validators), 2) + + validator_types = [type(v) for v in global_validators] + self.assertIn(IsStringValidator, validator_types) + self.assertIn(LengthValidator, validator_types) + + validated_data = filter_instance.validate_data({ + 'name': 'John', 'email': 'test@example.com' + }) + self.assertEqual(validated_data['name'], 'John') + + with self.assertRaises(ValidationError): + filter_instance.validate_data({ + 'name': 123, 'email': 456 + }) + + with self.assertRaises(ValidationError): + filter_instance.validate_data({ + 'name': 'ab', 'email': 'x' + }) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/declarative/test_model_decorator.py b/tests/declarative/test_model_decorator.py new file mode 100644 index 0000000..75eb8a9 --- /dev/null +++ b/tests/declarative/test_model_decorator.py @@ -0,0 +1,401 @@ +import unittest +from dataclasses import dataclass +from typing import Optional, Dict, Any +from flask_inputfilter import InputFilter +from flask_inputfilter.declarative import field, model +from flask_inputfilter.validators import IsStringValidator, IsIntegerValidator +from flask_inputfilter.exceptions import ValidationError + +try: + from typing import TypedDict + TYPEDDICT_AVAILABLE = True +except ImportError: + try: + from typing_extensions import TypedDict + TYPEDDICT_AVAILABLE = True + except ImportError: + TYPEDDICT_AVAILABLE = False + +try: + from pydantic import BaseModel + PYDANTIC_AVAILABLE = True +except ImportError: + PYDANTIC_AVAILABLE = False + + +class TestModelDecorator(unittest.TestCase): + + def test_model_decorator(self): + + @dataclass + class TestModel: + name: str + email: str + + class TestInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + email = field(required=True, validators=[IsStringValidator()]) + + model(TestModel) + + filter_instance = TestInputFilter() + + self.assertEqual(filter_instance.model_class, TestModel) + + test_data = {'name': 'John', 'email': 'john@example.com'} + validated_data = filter_instance.validate_data(test_data) + + self.assertIsInstance(validated_data, TestModel) + self.assertEqual(validated_data.name, 'John') + self.assertEqual(validated_data.email, 'john@example.com') + + def test_model_inheritance(self): + + @dataclass + class BaseModel: + name: str + + @dataclass + class ExtendedModel: + name: str + age: int + + class BaseInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + model(BaseModel) + + class ChildInputFilter(BaseInputFilter): + age = field(required=True) + + model(ExtendedModel) + + child_filter = ChildInputFilter() + + self.assertEqual(child_filter.model_class, ExtendedModel) + + validated_data = child_filter.validate_data({'name': 'John', 'age': 25}) + + self.assertIsInstance(validated_data, ExtendedModel) + self.assertEqual(validated_data.name, 'John') + self.assertEqual(validated_data.age, 25) + + def test_model_backward_compatibility(self): + + @dataclass + class TestModel: + name: str + + class OldStyleInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + _model = TestModel + + class NewStyleInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + model(TestModel) + + old_filter = OldStyleInputFilter() + new_filter = NewStyleInputFilter() + + self.assertEqual(old_filter.model_class, TestModel) + self.assertEqual(new_filter.model_class, TestModel) + + test_data = {'name': 'John'} + + old_validated = old_filter.validate_data(test_data) + new_validated = new_filter.validate_data(test_data) + + self.assertIsInstance(old_validated, TestModel) + self.assertIsInstance(new_validated, TestModel) + self.assertEqual(old_validated.name, 'John') + self.assertEqual(new_validated.name, 'John') + + def test_ruf012_solved_for_model(self): + + @dataclass + class ProblematicModel: + name: str + + class ProblematicInputFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + model(ProblematicModel) + + filter_instance = ProblematicInputFilter() + + self.assertEqual(filter_instance.model_class, ProblematicModel) + + validated_data = filter_instance.validate_data({'name': 'John'}) + + self.assertIsInstance(validated_data, ProblematicModel) + self.assertEqual(validated_data.name, 'John') + + self.assertTrue(hasattr(ProblematicInputFilter, '_model')) + + def test_missing_model_class_behavior(self): + """Test behavior when no model decorator is used.""" + + class NoModelFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + filter_instance = NoModelFilter() + + self.assertIsNone(getattr(filter_instance, 'model_class', None)) + + validated_data = filter_instance.validate_data({'name': 'John'}) + self.assertEqual(validated_data['name'], 'John') + + self.assertIsInstance(validated_data, dict) + self.assertEqual(validated_data['name'], 'John') + + @unittest.skipUnless(TYPEDDICT_AVAILABLE, "TypedDict not available") + def test_typeddict_models(self): + """Test with TypedDict instead of dataclass.""" + + class UserDict(TypedDict): + name: str + age: int + + class TypedDictFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + age = field(required=True, validators=[IsIntegerValidator()]) + + model(UserDict) + + filter_instance = TypedDictFilter() + self.assertEqual(filter_instance.model_class, UserDict) + + validated_data = filter_instance.validate_data({'name': 'John', 'age': 25}) + + self.assertIsInstance(validated_data, dict) + self.assertEqual(validated_data['name'], 'John') + self.assertEqual(validated_data['age'], 25) + + @unittest.skipUnless(PYDANTIC_AVAILABLE, "Pydantic not available") + def test_pydantic_models(self): + """Test with Pydantic model classes.""" + + class PydanticUser(BaseModel): + name: str + age: int + email: Optional[str] = None + + class PydanticFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + age = field(required=True, validators=[IsIntegerValidator()]) + email = field(required=False, validators=[IsStringValidator()]) + + model(PydanticUser) + + filter_instance = PydanticFilter() + self.assertEqual(filter_instance.model_class, PydanticUser) + + validated_data = filter_instance.validate_data({ + 'name': 'John', 'age': 25, 'email': 'john@example.com' + }) + + self.assertIsInstance(validated_data, PydanticUser) + self.assertEqual(validated_data.name, 'John') + self.assertEqual(validated_data.age, 25) + self.assertEqual(validated_data.email, 'john@example.com') + + def test_multiple_model_decorators_error(self): + """Test error handling with multiple model() calls.""" + + @dataclass + class FirstModel: + name: str + + @dataclass + class SecondModel: + title: str + + class MultipleModelFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + model(FirstModel) + model(SecondModel) + + filter_instance = MultipleModelFilter() + + self.assertEqual(filter_instance.model_class, SecondModel) + + def test_serialization_errors(self): + """Test handling of serialization failures.""" + + @dataclass + class StrictModel: + name: str + required_field: str + + class SerializationErrorFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + model(StrictModel) + + filter_instance = SerializationErrorFilter() + + with self.assertRaises(TypeError): + filter_instance.validate_data({'name': 'John'}) + + def test_partial_field_coverage(self): + """Test models with subset of filter fields.""" + + @dataclass + class PartialModel: + name: str + + class PartialCoverageFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + age = field(required=False, validators=[IsIntegerValidator()]) + email = field(required=False, validators=[IsStringValidator()]) + + model(PartialModel) + + filter_instance = PartialCoverageFilter() + + serialized = filter_instance.validate_data({'name': 'John'}) + + self.assertIsInstance(serialized, PartialModel) + self.assertEqual(serialized.name, 'John') + + def test_model_with_optional_fields(self): + """Test models with optional fields.""" + + @dataclass + class OptionalFieldsModel: + name: str + age: Optional[int] = None + email: Optional[str] = None + + class OptionalFieldsFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + age = field(required=False, validators=[IsIntegerValidator()]) + email = field(required=False, validators=[IsStringValidator()]) + + model(OptionalFieldsModel) + + filter_instance = OptionalFieldsFilter() + + validated_data1 = filter_instance.validate_data({'name': 'John', 'age': 25, 'email': 'john@example.com'}) + self.assertEqual(validated_data1.name, 'John') + self.assertEqual(validated_data1.age, 25) + self.assertEqual(validated_data1.email, 'john@example.com') + + validated_data2 = filter_instance.validate_data({'name': 'Jane'}) + self.assertEqual(validated_data2.name, 'Jane') + self.assertIsNone(validated_data2.age) + self.assertIsNone(validated_data2.email) + + def test_complex_model_types(self): + """Test with complex model field types.""" + + @dataclass + class ComplexModel: + name: str + metadata: Dict[str, Any] + tags: list + + class ComplexModelFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + metadata = field(required=False) + tags = field(required=False) + + model(ComplexModel) + + filter_instance = ComplexModelFilter() + + validated_data = filter_instance.validate_data({ + 'name': 'Test', + 'metadata': {'key1': 'value1', 'key2': 42}, + 'tags': ['tag1', 'tag2'] + }) + + self.assertEqual(validated_data.name, 'Test') + self.assertEqual(validated_data.metadata, {'key1': 'value1', 'key2': 42}) + self.assertEqual(validated_data.tags, ['tag1', 'tag2']) + + def test_model_inheritance_overrides(self): + """Test model decorator inheritance and overrides.""" + + @dataclass + class BaseUserModel: + name: str + + @dataclass + class ExtendedUserModel: + name: str + age: int + role: str + + class BaseModelFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + + model(BaseUserModel) + + class ExtendedModelFilter(BaseModelFilter): + age = field(required=True, validators=[IsIntegerValidator()]) + role = field(required=True, validators=[IsStringValidator()]) + + model(ExtendedUserModel) + + filter_instance = ExtendedModelFilter() + + self.assertEqual(filter_instance.model_class, ExtendedUserModel) + + test_data = {'name': 'John', 'age': 25, 'role': 'admin'} + validated_data = filter_instance.validate_data(test_data) + + self.assertIsInstance(validated_data, ExtendedUserModel) + self.assertEqual(validated_data.name, 'John') + self.assertEqual(validated_data.age, 25) + self.assertEqual(validated_data.role, 'admin') + + def test_model_with_nested_objects(self): + """Test models with nested object structures.""" + + @dataclass + class Address: + street: str + city: str + + @dataclass + class UserWithAddress: + name: str + address: Address + + class NestedModelFilter(InputFilter): + name = field(required=True, validators=[IsStringValidator()]) + address = field(required=True) + + model(UserWithAddress) + + filter_instance = NestedModelFilter() + + validated_data = filter_instance.validate_data({ + 'name': 'John', + 'address': {'street': '123 Main St', 'city': 'Anytown'} + }) + self.assertEqual(validated_data.name, 'John') + + def test_empty_model_class(self): + """Test with empty dataclass model.""" + + @dataclass + class EmptyModel: + pass + + class EmptyModelFilter(InputFilter): + model(EmptyModel) + + filter_instance = EmptyModelFilter() + self.assertEqual(filter_instance.model_class, EmptyModel) + + validated_data = filter_instance.validate_data({}) + self.assertIsInstance(validated_data, EmptyModel) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/filters/test_to_typed_dict_filter.py b/tests/filters/test_to_typed_dict_filter.py index 81700b9..4fb1d75 100644 --- a/tests/filters/test_to_typed_dict_filter.py +++ b/tests/filters/test_to_typed_dict_filter.py @@ -3,28 +3,39 @@ from flask_inputfilter import InputFilter from flask_inputfilter.filters import ToTypedDictFilter +try: + from typing import TypedDict + TYPEDDICT_AVAILABLE = True +except ImportError: + try: + from typing_extensions import TypedDict + TYPEDDICT_AVAILABLE = True + except ImportError: + TYPEDDICT_AVAILABLE = False + class TestToTypedDictFilter(unittest.TestCase): def setUp(self) -> None: self.input_filter = InputFilter() def test_converts_dict_to_typed_dict(self) -> None: - # TODO: Readd when Python 3.7 support is dropped - # class Person(TypedDict): - # name: str - # age: int - - class Person: - __annotations__ = {"name": str, "age": int} + if TYPEDDICT_AVAILABLE: + class Person(TypedDict): + name: str + age: int + else: + # Fallback for when TypedDict is not available + class Person: + __annotations__ = {"name": str, "age": int} - def __init__(self, name: str, age: int) -> None: - self.name = name - self.age = age + def __init__(self, name: str, age: int) -> None: + self.name = name + self.age = age - def __eq__(self, other): - if isinstance(other, dict): - return other == {"name": self.name, "age": self.age} - return NotImplemented + def __eq__(self, other): + if isinstance(other, dict): + return other == {"name": self.name, "age": self.age} + return NotImplemented self.input_filter.add( "person", required=True, filters=[ToTypedDictFilter(Person)] @@ -36,12 +47,18 @@ def __eq__(self, other): self.assertEqual(validated_data["person"], {"name": "John", "age": 25}) def test_non_dict_input_remains_unchanged(self) -> None: - class Person: - __annotations__ = {"name": str, "age": int} + if TYPEDDICT_AVAILABLE: + class Person(TypedDict): + name: str + age: int + else: + # Fallback for when TypedDict is not available + class Person: + __annotations__ = {"name": str, "age": int} - def __init__(self, name: str, age: int) -> None: - self.name = name - self.age = age + def __init__(self, name: str, age: int) -> None: + self.name = name + self.age = age self.input_filter.add( "person", required=True, filters=[ToTypedDictFilter(Person)] diff --git a/tests/input_filter/test_decorator_input_filter.py b/tests/input_filter/test_decorator_input_filter.py index 33ad636..745c81b 100644 --- a/tests/input_filter/test_decorator_input_filter.py +++ b/tests/input_filter/test_decorator_input_filter.py @@ -4,7 +4,7 @@ from flask import Flask, g, jsonify from flask_inputfilter import InputFilter -from flask_inputfilter.declarative import field +from flask_inputfilter.declarative import condition, field, global_filter, global_validator from flask_inputfilter.conditions import ExactlyOneOfCondition from flask_inputfilter.exceptions import ValidationError from flask_inputfilter.filters import ( @@ -125,7 +125,7 @@ class TestInputFilter(InputFilter): field1: str = field() field2: str = field() - _global_validators = [IsStringValidator()] + global_validator(IsStringValidator()) filter_instance = TestInputFilter() @@ -151,7 +151,7 @@ class TestInputFilter(InputFilter): field1: str = field() field2: str = field() - _global_filters = [ToUpperFilter()] + global_filter(ToUpperFilter()) filter_instance = TestInputFilter() @@ -170,7 +170,7 @@ class TestInputFilter(InputFilter): phone: str = field() email: str = field() - _conditions = [ExactlyOneOfCondition(['phone', 'email'])] + condition(ExactlyOneOfCondition(['phone', 'email'])) filter_instance = TestInputFilter() @@ -190,11 +190,11 @@ def test_inheritance_with_decorators(self): class BaseInputFilter(InputFilter): name: str = field(required=True) - _global_filters = [StringTrimFilter()] + global_filter(StringTrimFilter()) class ExtendedInputFilter(BaseInputFilter): age: int = field(required=True, validators=[IsIntegerValidator()]) - _global_validators = [IsStringValidator()] + global_validator(IsStringValidator()) filter_instance = ExtendedInputFilter() @@ -420,12 +420,9 @@ class TestInputFilter(InputFilter): self.assertEqual(validated_data['username'], 'testuser') self.assertTrue(validated_data['is_valid']) - # Performance and Edge Cases - def test_large_number_of_decorator_fields(self): """Test performance with many decorator fields.""" - # Create a class with many fields programmatically class_dict = {} for i in range(100): class_dict[f'field_{i}'] = field(default=f'value_{i}') @@ -468,20 +465,15 @@ class ComplexInputFilter(InputFilter): validators=[IsIntegerValidator()] ) - # Global components - _global_filters = [StringTrimFilter()] - _conditions = [ - # Custom condition example would go here if needed - ] + global_filter(StringTrimFilter()) def __init__(self): super().__init__() - # Classic API additions + self.add('email', required=False, validators=[IsStringValidator()]) filter_instance = ComplexInputFilter() - # Test successful complex validation validated_data = filter_instance.validate_data({ 'username': ' TestUser ', 'age': '25', diff --git a/tests/input_filter/test_input_filter.py b/tests/input_filter/test_input_filter.py index 5948c9c..72eaacf 100644 --- a/tests/input_filter/test_input_filter.py +++ b/tests/input_filter/test_input_filter.py @@ -1131,33 +1131,6 @@ def test_copy(self) -> None: ) self.assertEqual(validated_data["escapedUsername"], "test-user") - def test_serialize_and_set_model(self) -> None: - """Test that InputFilter.serialize() serializes the validated data.""" - - class User: - def __init__(self, username: str): - self.username = username - - @dataclass - class User2: - username: str - - self.inputFilter.add("username") - self.inputFilter.set_data({"username": "test user"}) - - self.inputFilter.is_valid() - - self.inputFilter.set_model(User) - self.assertEqual(self.inputFilter.serialize().username, "test user") - - self.inputFilter.set_model(None) - self.assertEqual( - self.inputFilter.serialize(), {"username": "test user"} - ) - - self.inputFilter.set_model(User2) - self.assertEqual(self.inputFilter.serialize().username, "test user") - def test_model_class_serialisation(self) -> None: """Test that the model class is serialized correctly.""" diff --git a/tests/input_filter/test_mixed_api.py b/tests/input_filter/test_mixed_api.py index 68f1cac..80a0b11 100644 --- a/tests/input_filter/test_mixed_api.py +++ b/tests/input_filter/test_mixed_api.py @@ -2,7 +2,7 @@ from flask import Flask, g, jsonify from flask_inputfilter import InputFilter -from flask_inputfilter.declarative import field +from flask_inputfilter.declarative import condition, field, global_filter, global_validator from flask_inputfilter.conditions import ExactlyOneOfCondition from flask_inputfilter.exceptions import ValidationError from flask_inputfilter.filters import StringTrimFilter, ToLowerFilter, ToUpperFilter @@ -62,8 +62,8 @@ class MixedInputFilter(InputFilter): field2: str = field() # Decorator-based global components - _global_filters = [StringTrimFilter()] - _global_validators = [IsStringValidator()] + global_filter(StringTrimFilter()) + global_validator(IsStringValidator()) def __init__(self): super().__init__() @@ -94,8 +94,7 @@ class MixedInputFilter(InputFilter): email: str = field() address: str = field() - # Decorator-based condition - _conditions = [ExactlyOneOfCondition(['phone', 'email'])] + condition(ExactlyOneOfCondition(['phone', 'email'])) def __init__(self): super().__init__() @@ -191,7 +190,7 @@ def test_inheritance_with_mixed_api(self): class BaseInputFilter(InputFilter): # Base decorator field name: str = field(required=True) - _global_filters = [StringTrimFilter()] + global_filter(StringTrimFilter()) def __init__(self): super().__init__() @@ -271,7 +270,7 @@ class ComplexMixedInputFilter(InputFilter): profile_type: str = field(required=True) # Global components via decorator - _global_filters = [StringTrimFilter()] + global_filter(StringTrimFilter()) def __init__(self): super().__init__() @@ -525,7 +524,7 @@ class ComplexMixedInputFilter(InputFilter): ) _model = ComplexModel - _global_filters = [StringTrimFilter()] + global_filter(StringTrimFilter()) def __init__(self): super().__init__() diff --git a/tests/validators/test_is_typed_dict_validator.py b/tests/validators/test_is_typed_dict_validator.py index 46e4ccd..72e1358 100644 --- a/tests/validators/test_is_typed_dict_validator.py +++ b/tests/validators/test_is_typed_dict_validator.py @@ -1,23 +1,34 @@ +import unittest from flask_inputfilter.exceptions import ValidationError from flask_inputfilter.validators import IsTypedDictValidator from tests.validators import BaseValidatorTest -# TODO: Readd when Python 3.7 support is dropped -# class User(TypedDict): -# id: int - - -class User: - __annotations__ = {"id": int} - - def __init__(self, id: int): - self.id = id - - def __eq__(self, other): - if isinstance(other, dict): - return other == {"id": self.id} - return NotImplemented +try: + from typing import TypedDict + TYPEDDICT_AVAILABLE = True +except ImportError: + try: + from typing_extensions import TypedDict + TYPEDDICT_AVAILABLE = True + except ImportError: + TYPEDDICT_AVAILABLE = False + +if TYPEDDICT_AVAILABLE: + class User(TypedDict): + id: int +else: + # Fallback for when TypedDict is not available + class User: + __annotations__ = {"id": int} + + def __init__(self, id: int): + self.id = id + + def __eq__(self, other): + if isinstance(other, dict): + return other == {"id": self.id} + return NotImplemented class TestIsTypedDictValidator(BaseValidatorTest):