Skip to content

Commit 0dc58e2

Browse files
committed
Add tests for inheritance
1 parent 016f406 commit 0dc58e2

File tree

5 files changed

+292
-0
lines changed

5 files changed

+292
-0
lines changed
File renamed without changes.
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import unittest
2+
3+
from flask_inputfilter import InputFilter
4+
from flask_inputfilter.declarative import field
5+
from flask_inputfilter.conditions import ExactlyOneOfCondition
6+
from flask_inputfilter.exceptions import ValidationError
7+
from flask_inputfilter.filters import ToLowerFilter, ToIntegerFilter, StringTrimFilter
8+
from flask_inputfilter.validators import IsStringValidator, IsIntegerValidator, LengthValidator
9+
10+
11+
class TestInputFilterInheritance(unittest.TestCase):
12+
"""Test suite for InputFilter inheritance behavior."""
13+
14+
def test_basic_inheritance(self):
15+
"""Test basic inheritance where child class adds new fields."""
16+
17+
class UserInputFilter(InputFilter):
18+
username = field(required=True, validators=[IsStringValidator()])
19+
email = field(required=True, validators=[IsStringValidator()])
20+
21+
class ProfileInputFilter(UserInputFilter):
22+
bio = field(required=False, default='No bio')
23+
age = field(required=False, validators=[IsIntegerValidator()])
24+
25+
profile_filter = ProfileInputFilter()
26+
27+
# Check that all fields are present (inherited + new)
28+
expected_fields = {'username', 'email', 'bio', 'age'}
29+
actual_fields = set(profile_filter.fields.keys())
30+
self.assertEqual(expected_fields, actual_fields)
31+
32+
# Check field properties are preserved
33+
username_field = profile_filter.get_input('username')
34+
bio_field = profile_filter.get_input('bio')
35+
36+
self.assertTrue(username_field.required)
37+
self.assertFalse(bio_field.required)
38+
self.assertEqual(bio_field.default, 'No bio')
39+
40+
def test_field_overriding(self):
41+
"""Test that child classes can override parent field definitions."""
42+
43+
class BaseInputFilter(InputFilter):
44+
name = field(required=True, validators=[IsStringValidator()])
45+
value = field(required=False)
46+
47+
class EnhancedInputFilter(BaseInputFilter):
48+
# Override with additional filters and validators
49+
name = field(
50+
required=True,
51+
validators=[IsStringValidator(), LengthValidator(min_length=3)],
52+
filters=[StringTrimFilter(), ToLowerFilter()]
53+
)
54+
# Add new field
55+
description = field(required=False, default='')
56+
57+
enhanced_filter = EnhancedInputFilter()
58+
59+
# Check that overridden field has new properties
60+
name_field = enhanced_filter.get_input('name')
61+
self.assertEqual(len(name_field.validators), 2)
62+
self.assertEqual(len(name_field.filters), 2)
63+
64+
# Check that both inherited and new fields exist
65+
self.assertTrue(enhanced_filter.has('name'))
66+
self.assertTrue(enhanced_filter.has('value'))
67+
self.assertTrue(enhanced_filter.has('description'))
68+
69+
def test_multi_level_inheritance(self):
70+
"""Test inheritance through multiple levels."""
71+
72+
class BaseInputFilter(InputFilter):
73+
id = field(required=True, validators=[IsIntegerValidator()])
74+
75+
class UserInputFilter(BaseInputFilter):
76+
username = field(required=True, validators=[IsStringValidator()])
77+
email = field(required=True, validators=[IsStringValidator()])
78+
79+
class AdminInputFilter(UserInputFilter):
80+
role = field(required=True, default='admin')
81+
permissions = field(required=False, default=[])
82+
83+
admin_filter = AdminInputFilter()
84+
85+
# Check all fields from all inheritance levels are present
86+
expected_fields = {'id', 'username', 'email', 'role', 'permissions'}
87+
actual_fields = set(admin_filter.fields.keys())
88+
self.assertEqual(expected_fields, actual_fields)
89+
90+
# Verify field count
91+
self.assertEqual(admin_filter.count(), 5)
92+
93+
def test_validation_with_inherited_fields(self):
94+
"""Test that validation works correctly with inherited fields."""
95+
96+
class UserInputFilter(InputFilter):
97+
username = field(required=True, validators=[IsStringValidator()])
98+
email = field(required=True, validators=[IsStringValidator()])
99+
100+
class ProfileInputFilter(UserInputFilter):
101+
bio = field(required=False, default='Developer')
102+
age = field(required=False, validators=[IsIntegerValidator()])
103+
104+
profile_filter = ProfileInputFilter()
105+
106+
# Test successful validation with all field types
107+
test_data = {
108+
'username': 'john_doe',
109+
'email': '[email protected]',
110+
'bio': 'Senior Developer',
111+
'age': 30
112+
}
113+
114+
validated_data = profile_filter.validate_data(test_data)
115+
self.assertEqual(validated_data['username'], 'john_doe')
116+
self.assertEqual(validated_data['email'], '[email protected]')
117+
self.assertEqual(validated_data['bio'], 'Senior Developer')
118+
self.assertEqual(validated_data['age'], 30)
119+
120+
# Test validation with missing optional fields (should use defaults)
121+
minimal_data = {
122+
'username': 'jane_doe',
123+
'email': '[email protected]'
124+
}
125+
126+
validated_data = profile_filter.validate_data(minimal_data)
127+
self.assertEqual(validated_data['bio'], 'Developer') # default value
128+
129+
# Test validation failure for required inherited field
130+
invalid_data = {
131+
'email': '[email protected]',
132+
# missing required 'username'
133+
}
134+
135+
with self.assertRaises(ValidationError) as context:
136+
profile_filter.validate_data(invalid_data)
137+
138+
errors = context.exception.args[0]
139+
self.assertIn('username', errors)
140+
141+
def test_conditions_inheritance(self):
142+
"""Test that conditions can be inherited and work with inherited
143+
fields."""
144+
145+
class BaseInputFilter(InputFilter):
146+
field_a = field(required=False)
147+
field_b = field(required=False)
148+
149+
def __init__(self):
150+
super().__init__()
151+
self.add_condition(ExactlyOneOfCondition(['field_a', 'field_b']))
152+
153+
class ExtendedInputFilter(BaseInputFilter):
154+
field_c = field(required=False)
155+
156+
extended_filter = ExtendedInputFilter()
157+
158+
# Check that condition is inherited
159+
conditions = extended_filter.get_conditions()
160+
self.assertEqual(len(conditions), 1)
161+
self.assertIsInstance(conditions[0], ExactlyOneOfCondition)
162+
163+
# Test that inherited condition works
164+
valid_data = {'field_a': 'value1'}
165+
validated_data = extended_filter.validate_data(valid_data)
166+
self.assertEqual(validated_data['field_a'], 'value1')
167+
168+
# Test condition violation
169+
invalid_data = {'field_a': 'value1', 'field_b': 'value2'}
170+
with self.assertRaises(ValidationError):
171+
extended_filter.validate_data(invalid_data)
172+
173+
def test_global_filters_and_validators_inheritance(self):
174+
"""Test that global filters and validators are inherited."""
175+
176+
class BaseInputFilter(InputFilter):
177+
name = field(required=True)
178+
179+
def __init__(self):
180+
super().__init__()
181+
self.add_global_filter(StringTrimFilter())
182+
self.add_global_validator(IsStringValidator())
183+
184+
class ChildInputFilter(BaseInputFilter):
185+
description = field(required=False, default='No description')
186+
187+
child_filter = ChildInputFilter()
188+
189+
# Check that global filters and validators are inherited
190+
global_filters = child_filter.get_global_filters()
191+
global_validators = child_filter.get_global_validators()
192+
193+
self.assertEqual(len(global_filters), 1)
194+
self.assertEqual(len(global_validators), 1)
195+
self.assertIsInstance(global_filters[0], StringTrimFilter)
196+
self.assertIsInstance(global_validators[0], IsStringValidator)
197+
198+
# Test that global filter works on inherited fields
199+
test_data = {'name': ' john ', 'description': ' A developer '}
200+
validated_data = child_filter.validate_data(test_data)
201+
self.assertEqual(validated_data['name'], 'john') # trimmed by global filter
202+
self.assertEqual(validated_data['description'], 'A developer') # also trimmed
203+
204+
def test_inheritance_with_programmatic_fields(self):
205+
"""Test inheritance when parent uses programmatic field addition."""
206+
207+
class ProgrammaticInputFilter(InputFilter):
208+
def __init__(self):
209+
super().__init__()
210+
self.add('username', required=True, validators=[IsStringValidator()])
211+
self.add('email', required=True, validators=[IsStringValidator()])
212+
213+
class DeclarativeChildInputFilter(ProgrammaticInputFilter):
214+
bio = field(required=False, default='No bio')
215+
age = field(required=False, validators=[IsIntegerValidator()])
216+
217+
child_filter = DeclarativeChildInputFilter()
218+
219+
# Check that both programmatic and declarative fields exist
220+
expected_fields = {'username', 'email', 'bio', 'age'}
221+
actual_fields = set(child_filter.fields.keys())
222+
self.assertEqual(expected_fields, actual_fields)
223+
224+
# Test validation works for both types
225+
test_data = {
226+
'username': 'test_user',
227+
'email': '[email protected]',
228+
'bio': 'Developer',
229+
'age': 28
230+
}
231+
232+
validated_data = child_filter.validate_data(test_data)
233+
self.assertEqual(len(validated_data), 4)
234+
self.assertEqual(validated_data['username'], 'test_user')
235+
self.assertEqual(validated_data['bio'], 'Developer')
236+
237+
def test_deep_inheritance_chain(self):
238+
"""Test a deep inheritance chain to ensure all levels work
239+
correctly."""
240+
241+
class Level1InputFilter(InputFilter):
242+
level1_field = field(required=False, default='level1')
243+
244+
class Level2InputFilter(Level1InputFilter):
245+
level2_field = field(required=False, default='level2')
246+
247+
class Level3InputFilter(Level2InputFilter):
248+
level3_field = field(required=False, default='level3')
249+
250+
class Level4InputFilter(Level3InputFilter):
251+
level4_field = field(required=False, default='level4')
252+
253+
deep_filter = Level4InputFilter()
254+
255+
# Check all fields from all levels are present
256+
expected_fields = {'level1_field', 'level2_field', 'level3_field', 'level4_field'}
257+
actual_fields = set(deep_filter.fields.keys())
258+
self.assertEqual(expected_fields, actual_fields)
259+
260+
# Test validation with all defaults
261+
validated_data = deep_filter.validate_data({})
262+
self.assertEqual(validated_data['level1_field'], 'level1')
263+
self.assertEqual(validated_data['level2_field'], 'level2')
264+
self.assertEqual(validated_data['level3_field'], 'level3')
265+
self.assertEqual(validated_data['level4_field'], 'level4')
266+
267+
def test_inheritance_does_not_modify_parent_class(self):
268+
"""Test that creating child classes doesn't modify parent class."""
269+
270+
class ParentInputFilter(InputFilter):
271+
parent_field = field(required=True)
272+
273+
# Create instance of parent before child class definition
274+
parent_instance_before = ParentInputFilter()
275+
276+
class ChildInputFilter(ParentInputFilter):
277+
child_field = field(required=True)
278+
279+
# Create instance of parent after child class definition
280+
parent_instance_after = ParentInputFilter()
281+
child_instance = ChildInputFilter()
282+
283+
# Parent instances should only have parent fields
284+
self.assertEqual(set(parent_instance_before.fields.keys()), {'parent_field'})
285+
self.assertEqual(set(parent_instance_after.fields.keys()), {'parent_field'})
286+
287+
# Child instance should have both fields
288+
self.assertEqual(set(child_instance.fields.keys()), {'parent_field', 'child_field'})
289+
290+
291+
if __name__ == '__main__':
292+
unittest.main()

0 commit comments

Comments
 (0)