Skip to content

Commit bd78e00

Browse files
committed
🧪 Tests for Option and Options
xtl.config.options:Options - Refactored _custom_validation() to enhance code reusability - Fixed a bug where the __setattr__() method would mutate an attribute's value if a ValidationError was caught and not raised tests.config.test_options - Some tests for Option and Options
1 parent 2f6d5a6 commit bd78e00

File tree

2 files changed

+167
-42
lines changed

2 files changed

+167
-42
lines changed

src/xtl/config/options.py

Lines changed: 96 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import partial
33
import os
44
import re
5-
from typing import Any, Callable
5+
from typing import Any, Callable, Optional
66
from typing_extensions import Self
77

88
from pydantic import (BaseModel, ConfigDict, Field, PrivateAttr, model_validator,
@@ -64,6 +64,9 @@ def Option(
6464
"""
6565
Create a field with custom validation, serialization and metadata.
6666
"""
67+
if default is PydanticUndefined and default_factory is _Unset:
68+
raise ValueError('Either \'default\' or \'default_factory\' must be provided')
69+
6770
if extra is _Unset:
6871
extra = {}
6972

@@ -153,6 +156,58 @@ def _get_custom_validators(cls) -> dict[str, dict[str, list]]:
153156
custom_validators[name]['after'].append(validator)
154157
return custom_validators
155158

159+
@staticmethod
160+
def _apply_validators(name: str, value: Any,
161+
validators: list[AfterValidator | BeforeValidator],
162+
errors: list[InitErrorDetails] = None) \
163+
-> tuple[Any, list[InitErrorDetails]]:
164+
if errors is None:
165+
errors = []
166+
new_errors = []
167+
168+
for validator in validators:
169+
try:
170+
value = validator.func(value)
171+
except ValueError as e:
172+
new_errors.append(InitErrorDetails(type='value_error', loc=(name,),
173+
input=value, ctx={'error': e}))
174+
175+
if new_errors:
176+
return _Unset, errors + new_errors
177+
else:
178+
return value, errors
179+
180+
@classmethod
181+
def _validate_before(cls, name: str, value: Any,
182+
validators: dict[str, list[BeforeValidator | AfterValidator]],
183+
errors: list[InitErrorDetails] = None,
184+
parse_env: bool = False) -> tuple[Any, list[InitErrorDetails]]:
185+
if errors is None:
186+
errors = []
187+
188+
if parse_env:
189+
value = cls._get_envvar(value)
190+
191+
# Apply validators
192+
value, new_errors = cls._apply_validators(name=name, value=value,
193+
validators=validators['before'],
194+
errors=errors)
195+
return value, new_errors
196+
197+
@classmethod
198+
def _validate_after(cls, name: str, value: Any,
199+
validators: dict[str, list[BeforeValidator | AfterValidator]],
200+
errors: list[InitErrorDetails] = None) \
201+
-> tuple[Any, list[InitErrorDetails]]:
202+
if errors is None:
203+
errors = []
204+
205+
# Apply validators
206+
value, new_errors = cls._apply_validators(name=name, value=value,
207+
validators=validators['after'],
208+
errors=errors)
209+
return value, new_errors
210+
156211
@model_validator(mode='wrap')
157212
@classmethod
158213
def _custom_validation(cls, data: Any, handler: ModelWrapValidatorHandler[Self]) \
@@ -186,22 +241,19 @@ def _custom_validation(cls, data: Any, handler: ModelWrapValidatorHandler[Self])
186241
# Apply before validators for raw data only
187242
if mode == 'dict':
188243
for name, value in data.items():
244+
new_errors = []
189245
if name not in validators:
190246
continue
191-
# Replace environment variables in the value
192-
if parse_env:
193-
value = cls._get_envvar(value)
247+
value, new_errors = cls._validate_before(name=name, value=value,
248+
validators=validators[name],
249+
errors=new_errors,
250+
parse_env=parse_env)
251+
252+
if not new_errors:
194253
# Update value
195254
data[name] = value
196-
for validator in validators[name]['before']:
197-
try:
198-
# Call the validator function
199-
value = validator.func(value)
200-
# Update value
201-
data[name] = value
202-
except ValueError as e:
203-
errors.append(InitErrorDetails(type='value_error', loc=(name,),
204-
input=value, ctx={'error': e}))
255+
256+
errors.extend(new_errors)
205257
# Check and raise validation errors
206258
if errors:
207259
raise ValidationError.from_exception_data(title='before_validators',
@@ -212,46 +264,48 @@ def _custom_validation(cls, data: Any, handler: ModelWrapValidatorHandler[Self])
212264

213265
# Apply after validators
214266
for name, value in validated_self.model_dump().items():
267+
new_errors = []
215268
if name not in validators:
216269
continue
217-
for validator in validators[name]['after']:
218-
try:
219-
# Call the validator function
220-
value = validator.func(value)
221-
# Update value
222-
if mode == 'dict':
223-
setattr(validated_self, name, value)
224-
elif mode == 'pydantic':
225-
# Note: Workaround to prevent infinite recursion when using
226-
# setattr
227-
validated_self.__dict__[name] = value
228-
except ValueError as e:
229-
errors.append(InitErrorDetails(type='value_error', loc=(name,),
230-
input=value, ctx={'error': e}))
270+
value, new_errors = cls._validate_after(name=name, value=value,
271+
validators=validators[name],
272+
errors=new_errors)
273+
if not new_errors:
274+
# Note: Workaround to prevent infinite recursion when using
275+
# setattr
276+
validated_self.__dict__[name] = value
277+
errors.extend(new_errors)
231278
# Check and raise validation errors
232279
if errors:
233280
raise ValidationError.from_exception_data(title='after_validators',
234281
line_errors=errors)
235282
return validated_self
236283

237284
def __setattr__(self, name: str, value: Any) -> None:
285+
# Check if the attribute is a pydantic field
286+
if name not in self.__pydantic_fields__:
287+
super().__setattr__(name, value)
288+
return
289+
290+
# Apply before validators
291+
validators = self._get_custom_validators()
238292
errors = []
239-
# Check if the attribute is a field
240-
if name in self.__dict__:
241-
if name in self._get_custom_validators():
242-
# Replace environment variables in the value
243-
if self._parse_env:
244-
value = self._get_envvar(value)
245-
# Apply before validators
246-
for validator in self._get_custom_validators()[name]['before']:
247-
try:
248-
value = validator.func(value)
249-
except ValueError as e:
250-
errors.append(InitErrorDetails(type='value_error', loc=(name,),
251-
input=value, ctx={'error': e}))
293+
if name in validators:
294+
value, errors = self._validate_before(name=name, value=value,
295+
validators=validators[name],
296+
parse_env=self._parse_env)
252297
# Check and raise validation errors
253298
if errors:
254299
raise ValidationError.from_exception_data(title='before_validators',
255300
line_errors=errors)
256-
# Continue with the default behavior
257-
super().__setattr__(name, value)
301+
302+
try:
303+
# Validate the model with the new value
304+
data = self.model_dump()
305+
data.update({name: value})
306+
self.model_validate(data)
307+
# Update the attribute if validation is successful
308+
super().__setattr__(name, value)
309+
except ValidationError as e:
310+
raise e
311+

tests/config/test_options.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
3+
from pydantic import BeforeValidator, AfterValidator, Field, ValidationError
4+
5+
from xtl.config.options import Option, Options
6+
from xtl.config.validators import cast_as, validate_length
7+
8+
9+
class TestOption:
10+
11+
def test_no_defaults(self):
12+
with pytest.raises(ValueError, match='Either \'default\' or \'default_factory\' '
13+
'must be provided'):
14+
o = Option()
15+
16+
def test_validators(self):
17+
o = Option(default=1)
18+
assert not o.json_schema_extra
19+
20+
o = Option(default=1, cast_as=str, length=3)
21+
assert o.json_schema_extra
22+
assert o.json_schema_extra['validators']
23+
# Check the signature of the validator functions
24+
# validator.func -> partial function
25+
# partial.func -> original function
26+
assert o.json_schema_extra['validators'][0].func.func == validate_length
27+
assert o.json_schema_extra['validators'][1].func.func == cast_as
28+
29+
30+
class TestOptions:
31+
32+
class MyModel(Options):
33+
name: str = Option(default=None, choices=('Alice', 'Bob', 'Charlie'))
34+
age: int = Option(default=None, gt=0)
35+
double_this: int = Option(default=1, cast_as=lambda x: 2 * int(x))
36+
# pydantic.Field can also be used along with Option
37+
field: float = Field()
38+
39+
def test_init(self):
40+
m = self.MyModel(name='Alice', age=2, double_this='3', field=1.5)
41+
assert m.model_dump() == {
42+
'name': 'Alice',
43+
'age': 2,
44+
'double_this': 6,
45+
'field': 1.5
46+
}
47+
48+
def test_assignment(self):
49+
m = self.MyModel(name='Alice', age=2, double_this=3, field=1.5)
50+
m.name = 'Bob'
51+
m.age = 3
52+
m.double_this = 4
53+
m.field = 2.5
54+
assert m.model_dump() == {
55+
'name': 'Bob',
56+
'age': 3,
57+
'double_this': 8,
58+
'field': 2.5
59+
}
60+
61+
# Test immutability when validation errors occur
62+
with pytest.raises(ValidationError, match='Value is not in choices'):
63+
m.name = 'Dave'
64+
assert m.model_dump() == {
65+
'name': 'Bob',
66+
'age': 3,
67+
'double_this': 8,
68+
'field': 2.5
69+
}
70+
71+

0 commit comments

Comments
 (0)