|
| 1 | +import unittest |
| 2 | + |
| 3 | +from flask_inputfilter import InputFilter |
| 4 | +from flask_inputfilter.declarative import field |
| 5 | +from flask_inputfilter.conditions import ExactlyOneOfCondition |
| 6 | +from flask_inputfilter.exceptions import ValidationError |
| 7 | +from flask_inputfilter.filters import ToLowerFilter, ToIntegerFilter, StringTrimFilter |
| 8 | +from flask_inputfilter.validators import IsStringValidator, IsIntegerValidator, LengthValidator |
| 9 | + |
| 10 | + |
| 11 | +class TestInputFilterInheritance(unittest.TestCase): |
| 12 | + """Test suite for InputFilter inheritance behavior.""" |
| 13 | + |
| 14 | + def test_basic_inheritance(self): |
| 15 | + """Test basic inheritance where child class adds new fields.""" |
| 16 | + |
| 17 | + class UserInputFilter(InputFilter): |
| 18 | + username = field(required=True, validators=[IsStringValidator()]) |
| 19 | + email = field(required=True, validators=[IsStringValidator()]) |
| 20 | + |
| 21 | + class ProfileInputFilter(UserInputFilter): |
| 22 | + bio = field(required=False, default='No bio') |
| 23 | + age = field(required=False, validators=[IsIntegerValidator()]) |
| 24 | + |
| 25 | + profile_filter = ProfileInputFilter() |
| 26 | + |
| 27 | + # Check that all fields are present (inherited + new) |
| 28 | + expected_fields = {'username', 'email', 'bio', 'age'} |
| 29 | + actual_fields = set(profile_filter.fields.keys()) |
| 30 | + self.assertEqual(expected_fields, actual_fields) |
| 31 | + |
| 32 | + # Check field properties are preserved |
| 33 | + username_field = profile_filter.get_input('username') |
| 34 | + bio_field = profile_filter.get_input('bio') |
| 35 | + |
| 36 | + self.assertTrue(username_field.required) |
| 37 | + self.assertFalse(bio_field.required) |
| 38 | + self.assertEqual(bio_field.default, 'No bio') |
| 39 | + |
| 40 | + def test_field_overriding(self): |
| 41 | + """Test that child classes can override parent field definitions.""" |
| 42 | + |
| 43 | + class BaseInputFilter(InputFilter): |
| 44 | + name = field(required=True, validators=[IsStringValidator()]) |
| 45 | + value = field(required=False) |
| 46 | + |
| 47 | + class EnhancedInputFilter(BaseInputFilter): |
| 48 | + # Override with additional filters and validators |
| 49 | + name = field( |
| 50 | + required=True, |
| 51 | + validators=[IsStringValidator(), LengthValidator(min_length=3)], |
| 52 | + filters=[StringTrimFilter(), ToLowerFilter()] |
| 53 | + ) |
| 54 | + # Add new field |
| 55 | + description = field(required=False, default='') |
| 56 | + |
| 57 | + enhanced_filter = EnhancedInputFilter() |
| 58 | + |
| 59 | + # Check that overridden field has new properties |
| 60 | + name_field = enhanced_filter.get_input('name') |
| 61 | + self.assertEqual(len(name_field.validators), 2) |
| 62 | + self.assertEqual(len(name_field.filters), 2) |
| 63 | + |
| 64 | + # Check that both inherited and new fields exist |
| 65 | + self.assertTrue(enhanced_filter.has('name')) |
| 66 | + self.assertTrue(enhanced_filter.has('value')) |
| 67 | + self.assertTrue(enhanced_filter.has('description')) |
| 68 | + |
| 69 | + def test_multi_level_inheritance(self): |
| 70 | + """Test inheritance through multiple levels.""" |
| 71 | + |
| 72 | + class BaseInputFilter(InputFilter): |
| 73 | + id = field(required=True, validators=[IsIntegerValidator()]) |
| 74 | + |
| 75 | + class UserInputFilter(BaseInputFilter): |
| 76 | + username = field(required=True, validators=[IsStringValidator()]) |
| 77 | + email = field(required=True, validators=[IsStringValidator()]) |
| 78 | + |
| 79 | + class AdminInputFilter(UserInputFilter): |
| 80 | + role = field(required=True, default='admin') |
| 81 | + permissions = field(required=False, default=[]) |
| 82 | + |
| 83 | + admin_filter = AdminInputFilter() |
| 84 | + |
| 85 | + # Check all fields from all inheritance levels are present |
| 86 | + expected_fields = {'id', 'username', 'email', 'role', 'permissions'} |
| 87 | + actual_fields = set(admin_filter.fields.keys()) |
| 88 | + self.assertEqual(expected_fields, actual_fields) |
| 89 | + |
| 90 | + # Verify field count |
| 91 | + self.assertEqual(admin_filter.count(), 5) |
| 92 | + |
| 93 | + def test_validation_with_inherited_fields(self): |
| 94 | + """Test that validation works correctly with inherited fields.""" |
| 95 | + |
| 96 | + class UserInputFilter(InputFilter): |
| 97 | + username = field(required=True, validators=[IsStringValidator()]) |
| 98 | + email = field(required=True, validators=[IsStringValidator()]) |
| 99 | + |
| 100 | + class ProfileInputFilter(UserInputFilter): |
| 101 | + bio = field(required=False, default='Developer') |
| 102 | + age = field(required=False, validators=[IsIntegerValidator()]) |
| 103 | + |
| 104 | + profile_filter = ProfileInputFilter() |
| 105 | + |
| 106 | + # Test successful validation with all field types |
| 107 | + test_data = { |
| 108 | + 'username': 'john_doe', |
| 109 | + |
| 110 | + 'bio': 'Senior Developer', |
| 111 | + 'age': 30 |
| 112 | + } |
| 113 | + |
| 114 | + validated_data = profile_filter.validate_data(test_data) |
| 115 | + self.assertEqual(validated_data['username'], 'john_doe') |
| 116 | + self. assertEqual( validated_data[ 'email'], '[email protected]') |
| 117 | + self.assertEqual(validated_data['bio'], 'Senior Developer') |
| 118 | + self.assertEqual(validated_data['age'], 30) |
| 119 | + |
| 120 | + # Test validation with missing optional fields (should use defaults) |
| 121 | + minimal_data = { |
| 122 | + 'username': 'jane_doe', |
| 123 | + |
| 124 | + } |
| 125 | + |
| 126 | + validated_data = profile_filter.validate_data(minimal_data) |
| 127 | + self.assertEqual(validated_data['bio'], 'Developer') # default value |
| 128 | + |
| 129 | + # Test validation failure for required inherited field |
| 130 | + invalid_data = { |
| 131 | + |
| 132 | + # missing required 'username' |
| 133 | + } |
| 134 | + |
| 135 | + with self.assertRaises(ValidationError) as context: |
| 136 | + profile_filter.validate_data(invalid_data) |
| 137 | + |
| 138 | + errors = context.exception.args[0] |
| 139 | + self.assertIn('username', errors) |
| 140 | + |
| 141 | + def test_conditions_inheritance(self): |
| 142 | + """Test that conditions can be inherited and work with inherited |
| 143 | + fields.""" |
| 144 | + |
| 145 | + class BaseInputFilter(InputFilter): |
| 146 | + field_a = field(required=False) |
| 147 | + field_b = field(required=False) |
| 148 | + |
| 149 | + def __init__(self): |
| 150 | + super().__init__() |
| 151 | + self.add_condition(ExactlyOneOfCondition(['field_a', 'field_b'])) |
| 152 | + |
| 153 | + class ExtendedInputFilter(BaseInputFilter): |
| 154 | + field_c = field(required=False) |
| 155 | + |
| 156 | + extended_filter = ExtendedInputFilter() |
| 157 | + |
| 158 | + # Check that condition is inherited |
| 159 | + conditions = extended_filter.get_conditions() |
| 160 | + self.assertEqual(len(conditions), 1) |
| 161 | + self.assertIsInstance(conditions[0], ExactlyOneOfCondition) |
| 162 | + |
| 163 | + # Test that inherited condition works |
| 164 | + valid_data = {'field_a': 'value1'} |
| 165 | + validated_data = extended_filter.validate_data(valid_data) |
| 166 | + self.assertEqual(validated_data['field_a'], 'value1') |
| 167 | + |
| 168 | + # Test condition violation |
| 169 | + invalid_data = {'field_a': 'value1', 'field_b': 'value2'} |
| 170 | + with self.assertRaises(ValidationError): |
| 171 | + extended_filter.validate_data(invalid_data) |
| 172 | + |
| 173 | + def test_global_filters_and_validators_inheritance(self): |
| 174 | + """Test that global filters and validators are inherited.""" |
| 175 | + |
| 176 | + class BaseInputFilter(InputFilter): |
| 177 | + name = field(required=True) |
| 178 | + |
| 179 | + def __init__(self): |
| 180 | + super().__init__() |
| 181 | + self.add_global_filter(StringTrimFilter()) |
| 182 | + self.add_global_validator(IsStringValidator()) |
| 183 | + |
| 184 | + class ChildInputFilter(BaseInputFilter): |
| 185 | + description = field(required=False, default='No description') |
| 186 | + |
| 187 | + child_filter = ChildInputFilter() |
| 188 | + |
| 189 | + # Check that global filters and validators are inherited |
| 190 | + global_filters = child_filter.get_global_filters() |
| 191 | + global_validators = child_filter.get_global_validators() |
| 192 | + |
| 193 | + self.assertEqual(len(global_filters), 1) |
| 194 | + self.assertEqual(len(global_validators), 1) |
| 195 | + self.assertIsInstance(global_filters[0], StringTrimFilter) |
| 196 | + self.assertIsInstance(global_validators[0], IsStringValidator) |
| 197 | + |
| 198 | + # Test that global filter works on inherited fields |
| 199 | + test_data = {'name': ' john ', 'description': ' A developer '} |
| 200 | + validated_data = child_filter.validate_data(test_data) |
| 201 | + self.assertEqual(validated_data['name'], 'john') # trimmed by global filter |
| 202 | + self.assertEqual(validated_data['description'], 'A developer') # also trimmed |
| 203 | + |
| 204 | + def test_inheritance_with_programmatic_fields(self): |
| 205 | + """Test inheritance when parent uses programmatic field addition.""" |
| 206 | + |
| 207 | + class ProgrammaticInputFilter(InputFilter): |
| 208 | + def __init__(self): |
| 209 | + super().__init__() |
| 210 | + self.add('username', required=True, validators=[IsStringValidator()]) |
| 211 | + self.add('email', required=True, validators=[IsStringValidator()]) |
| 212 | + |
| 213 | + class DeclarativeChildInputFilter(ProgrammaticInputFilter): |
| 214 | + bio = field(required=False, default='No bio') |
| 215 | + age = field(required=False, validators=[IsIntegerValidator()]) |
| 216 | + |
| 217 | + child_filter = DeclarativeChildInputFilter() |
| 218 | + |
| 219 | + # Check that both programmatic and declarative fields exist |
| 220 | + expected_fields = {'username', 'email', 'bio', 'age'} |
| 221 | + actual_fields = set(child_filter.fields.keys()) |
| 222 | + self.assertEqual(expected_fields, actual_fields) |
| 223 | + |
| 224 | + # Test validation works for both types |
| 225 | + test_data = { |
| 226 | + 'username': 'test_user', |
| 227 | + |
| 228 | + 'bio': 'Developer', |
| 229 | + 'age': 28 |
| 230 | + } |
| 231 | + |
| 232 | + validated_data = child_filter.validate_data(test_data) |
| 233 | + self.assertEqual(len(validated_data), 4) |
| 234 | + self.assertEqual(validated_data['username'], 'test_user') |
| 235 | + self.assertEqual(validated_data['bio'], 'Developer') |
| 236 | + |
| 237 | + def test_deep_inheritance_chain(self): |
| 238 | + """Test a deep inheritance chain to ensure all levels work |
| 239 | + correctly.""" |
| 240 | + |
| 241 | + class Level1InputFilter(InputFilter): |
| 242 | + level1_field = field(required=False, default='level1') |
| 243 | + |
| 244 | + class Level2InputFilter(Level1InputFilter): |
| 245 | + level2_field = field(required=False, default='level2') |
| 246 | + |
| 247 | + class Level3InputFilter(Level2InputFilter): |
| 248 | + level3_field = field(required=False, default='level3') |
| 249 | + |
| 250 | + class Level4InputFilter(Level3InputFilter): |
| 251 | + level4_field = field(required=False, default='level4') |
| 252 | + |
| 253 | + deep_filter = Level4InputFilter() |
| 254 | + |
| 255 | + # Check all fields from all levels are present |
| 256 | + expected_fields = {'level1_field', 'level2_field', 'level3_field', 'level4_field'} |
| 257 | + actual_fields = set(deep_filter.fields.keys()) |
| 258 | + self.assertEqual(expected_fields, actual_fields) |
| 259 | + |
| 260 | + # Test validation with all defaults |
| 261 | + validated_data = deep_filter.validate_data({}) |
| 262 | + self.assertEqual(validated_data['level1_field'], 'level1') |
| 263 | + self.assertEqual(validated_data['level2_field'], 'level2') |
| 264 | + self.assertEqual(validated_data['level3_field'], 'level3') |
| 265 | + self.assertEqual(validated_data['level4_field'], 'level4') |
| 266 | + |
| 267 | + def test_inheritance_does_not_modify_parent_class(self): |
| 268 | + """Test that creating child classes doesn't modify parent class.""" |
| 269 | + |
| 270 | + class ParentInputFilter(InputFilter): |
| 271 | + parent_field = field(required=True) |
| 272 | + |
| 273 | + # Create instance of parent before child class definition |
| 274 | + parent_instance_before = ParentInputFilter() |
| 275 | + |
| 276 | + class ChildInputFilter(ParentInputFilter): |
| 277 | + child_field = field(required=True) |
| 278 | + |
| 279 | + # Create instance of parent after child class definition |
| 280 | + parent_instance_after = ParentInputFilter() |
| 281 | + child_instance = ChildInputFilter() |
| 282 | + |
| 283 | + # Parent instances should only have parent fields |
| 284 | + self.assertEqual(set(parent_instance_before.fields.keys()), {'parent_field'}) |
| 285 | + self.assertEqual(set(parent_instance_after.fields.keys()), {'parent_field'}) |
| 286 | + |
| 287 | + # Child instance should have both fields |
| 288 | + self.assertEqual(set(child_instance.fields.keys()), {'parent_field', 'child_field'}) |
| 289 | + |
| 290 | + |
| 291 | +if __name__ == '__main__': |
| 292 | + unittest.main() |
0 commit comments