Skip to content

Commit 683c5a3

Browse files
authored
allow serialization functions to upgrade warnings to exceptions (#1258)
Signed-off-by: Lance Drane <[email protected]> Signed-off-by: Lance-Drane <[email protected]>
1 parent ee06335 commit 683c5a3

File tree

8 files changed

+148
-31
lines changed

8 files changed

+148
-31
lines changed

.mypy-stubtest-allowlist

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22
pydantic_core._pydantic_core.PydanticUndefinedType.new
33
# As per #1240, from_json has custom logic to coverage the `cache_strings` kwarg
44
pydantic_core._pydantic_core.from_json
5+
# the `warnings` kwarg for SchemaSerializer functions has custom logic
6+
pydantic_core._pydantic_core.SchemaSerializer.to_python
7+
pydantic_core._pydantic_core.SchemaSerializer.to_json

python/pydantic_core/_pydantic_core.pyi

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ class SchemaSerializer:
264264
exclude_defaults: bool = False,
265265
exclude_none: bool = False,
266266
round_trip: bool = False,
267-
warnings: bool = True,
267+
warnings: bool | Literal['none', 'warn', 'error'] = True,
268268
fallback: Callable[[Any], Any] | None = None,
269269
serialize_as_any: bool = False,
270270
context: dict[str, Any] | None = None,
@@ -284,7 +284,8 @@ class SchemaSerializer:
284284
exclude_defaults: Whether to exclude fields that are equal to their default value.
285285
exclude_none: Whether to exclude fields that have a value of `None`.
286286
round_trip: Whether to enable serialization and validation round-trip support.
287-
warnings: Whether to log warnings when invalid fields are encountered.
287+
warnings: How to handle invalid fields. False/"none" ignores them, True/"warn" logs errors,
288+
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
288289
fallback: A function to call when an unknown value is encountered,
289290
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
290291
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
@@ -309,7 +310,7 @@ class SchemaSerializer:
309310
exclude_defaults: bool = False,
310311
exclude_none: bool = False,
311312
round_trip: bool = False,
312-
warnings: bool = True,
313+
warnings: bool | Literal['none', 'warn', 'error'] = True,
313314
fallback: Callable[[Any], Any] | None = None,
314315
serialize_as_any: bool = False,
315316
context: dict[str, Any] | None = None,
@@ -328,7 +329,8 @@ class SchemaSerializer:
328329
exclude_defaults: Whether to exclude fields that are equal to their default value.
329330
exclude_none: Whether to exclude fields that have a value of `None`.
330331
round_trip: Whether to enable serialization and validation round-trip support.
331-
warnings: Whether to log warnings when invalid fields are encountered.
332+
warnings: How to handle invalid fields. False/"none" ignores them, True/"warn" logs errors,
333+
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
332334
fallback: A function to call when an unknown value is encountered,
333335
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
334336
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub use errors::{
3434
};
3535
pub use serializers::{
3636
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
37+
WarningsArg,
3738
};
3839
pub use validators::{validate_core_schema, PySome, SchemaValidator};
3940

src/serializers/extra.rs

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use std::cell::RefCell;
22
use std::fmt;
33

4-
use pyo3::exceptions::PyValueError;
4+
use pyo3::exceptions::{PyTypeError, PyValueError};
55
use pyo3::intern;
66
use pyo3::prelude::*;
7+
use pyo3::types::PyBool;
78

89
use serde::ser::Error;
910

@@ -14,6 +15,7 @@ use crate::recursion_guard::ContainsRecursionState;
1415
use crate::recursion_guard::RecursionError;
1516
use crate::recursion_guard::RecursionGuard;
1617
use crate::recursion_guard::RecursionState;
18+
use crate::PydanticSerializationError;
1719

1820
/// this is ugly, would be much better if extra could be stored in `SerializationState`
1921
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
@@ -64,7 +66,7 @@ impl DuckTypingSerMode {
6466

6567
impl SerializationState {
6668
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
67-
let warnings = CollectWarnings::new(false);
69+
let warnings = CollectWarnings::new(WarningsMode::None);
6870
let rec_guard = SerRecursionState::default();
6971
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
7072
Ok(Self {
@@ -325,23 +327,61 @@ impl ToPyObject for SerMode {
325327
}
326328
}
327329

330+
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
331+
pub enum WarningsMode {
332+
None,
333+
Warn,
334+
Error,
335+
}
336+
337+
impl<'py> FromPyObject<'py> for WarningsMode {
338+
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<WarningsMode> {
339+
if let Ok(bool_mode) = ob.downcast::<PyBool>() {
340+
Ok(bool_mode.is_true().into())
341+
} else if let Ok(str_mode) = ob.extract::<&str>() {
342+
match str_mode {
343+
"none" => Ok(Self::None),
344+
"warn" => Ok(Self::Warn),
345+
"error" => Ok(Self::Error),
346+
_ => Err(PyValueError::new_err(
347+
"Invalid warnings parameter, should be `'none'`, `'warn'`, `'error'` or a `bool`",
348+
)),
349+
}
350+
} else {
351+
Err(PyTypeError::new_err(
352+
"Invalid warnings parameter, should be `'none'`, `'warn'`, `'error'` or a `bool`",
353+
))
354+
}
355+
}
356+
}
357+
358+
impl From<bool> for WarningsMode {
359+
fn from(mode: bool) -> Self {
360+
if mode {
361+
Self::Warn
362+
} else {
363+
Self::None
364+
}
365+
}
366+
}
367+
328368
#[derive(Clone)]
329369
#[cfg_attr(debug_assertions, derive(Debug))]
330370
pub(crate) struct CollectWarnings {
331-
active: bool,
371+
mode: WarningsMode,
332372
warnings: RefCell<Option<Vec<String>>>,
333373
}
334374

335375
impl CollectWarnings {
336-
pub(crate) fn new(active: bool) -> Self {
376+
pub(crate) fn new(mode: WarningsMode) -> Self {
337377
Self {
338-
active,
378+
mode,
339379
warnings: RefCell::new(None),
340380
}
341381
}
342382

343383
pub fn custom_warning(&self, warning: String) {
344-
if self.active {
384+
if self.mode != WarningsMode::None {
345385
self.add_warning(warning);
346386
}
347387
}
@@ -379,7 +419,7 @@ impl CollectWarnings {
379419
}
380420

381421
fn fallback_warning(&self, field_type: &str, value: &Bound<'_, PyAny>) {
382-
if self.active {
422+
if self.mode != WarningsMode::None {
383423
let type_name = value
384424
.get_type()
385425
.qualname()
@@ -400,17 +440,20 @@ impl CollectWarnings {
400440
}
401441

402442
pub fn final_check(&self, py: Python) -> PyResult<()> {
403-
if self.active {
404-
match *self.warnings.borrow() {
405-
Some(ref warnings) => {
406-
let message = format!("Pydantic serializer warnings:\n {}", warnings.join("\n "));
443+
if self.mode == WarningsMode::None {
444+
return Ok(());
445+
}
446+
match *self.warnings.borrow() {
447+
Some(ref warnings) => {
448+
let message = format!("Pydantic serializer warnings:\n {}", warnings.join("\n "));
449+
if self.mode == WarningsMode::Warn {
407450
let user_warning_type = py.import_bound("builtins")?.getattr("UserWarning")?;
408451
PyErr::warn_bound(py, &user_warning_type, &message, 0)
452+
} else {
453+
Err(PydanticSerializationError::new_err(message))
409454
}
410-
_ => Ok(()),
411455
}
412-
} else {
413-
Ok(())
456+
_ => Ok(()),
414457
}
415458
}
416459
}

src/serializers/mod.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;
1010

1111
use config::SerializationConfig;
1212
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
13-
use extra::{CollectWarnings, SerRecursionState};
13+
use extra::{CollectWarnings, SerRecursionState, WarningsMode};
1414
pub(crate) use extra::{DuckTypingSerMode, Extra, SerMode, SerializationState};
1515
pub use shared::CombinedSerializer;
1616
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
@@ -27,6 +27,12 @@ pub mod ser;
2727
mod shared;
2828
mod type_serializers;
2929

30+
#[derive(FromPyObject)]
31+
pub enum WarningsArg {
32+
Bool(bool),
33+
Literal(WarningsMode),
34+
}
35+
3036
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
3137
#[derive(Debug)]
3238
pub struct SchemaSerializer {
@@ -98,7 +104,7 @@ impl SchemaSerializer {
98104

99105
#[allow(clippy::too_many_arguments)]
100106
#[pyo3(signature = (value, *, mode = None, include = None, exclude = None, by_alias = true,
101-
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true,
107+
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = WarningsArg::Bool(true),
102108
fallback = None, serialize_as_any = false, context = None))]
103109
pub fn to_python(
104110
&self,
@@ -112,13 +118,17 @@ impl SchemaSerializer {
112118
exclude_defaults: bool,
113119
exclude_none: bool,
114120
round_trip: bool,
115-
warnings: bool,
121+
warnings: WarningsArg,
116122
fallback: Option<&Bound<'_, PyAny>>,
117123
serialize_as_any: bool,
118124
context: Option<&Bound<'_, PyAny>>,
119125
) -> PyResult<PyObject> {
120126
let mode: SerMode = mode.into();
121-
let warnings = CollectWarnings::new(warnings);
127+
let warnings_mode = match warnings {
128+
WarningsArg::Bool(b) => b.into(),
129+
WarningsArg::Literal(mode) => mode,
130+
};
131+
let warnings = CollectWarnings::new(warnings_mode);
122132
let rec_guard = SerRecursionState::default();
123133
let duck_typing_ser_mode = DuckTypingSerMode::from_bool(serialize_as_any);
124134
let extra = self.build_extra(
@@ -143,7 +153,7 @@ impl SchemaSerializer {
143153

144154
#[allow(clippy::too_many_arguments)]
145155
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
146-
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true,
156+
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = WarningsArg::Bool(true),
147157
fallback = None, serialize_as_any = false, context = None))]
148158
pub fn to_json(
149159
&self,
@@ -157,12 +167,16 @@ impl SchemaSerializer {
157167
exclude_defaults: bool,
158168
exclude_none: bool,
159169
round_trip: bool,
160-
warnings: bool,
170+
warnings: WarningsArg,
161171
fallback: Option<&Bound<'_, PyAny>>,
162172
serialize_as_any: bool,
163173
context: Option<&Bound<'_, PyAny>>,
164174
) -> PyResult<PyObject> {
165-
let warnings = CollectWarnings::new(warnings);
175+
let warnings_mode = match warnings {
176+
WarningsArg::Bool(b) => b.into(),
177+
WarningsArg::Literal(mode) => mode,
178+
};
179+
let warnings = CollectWarnings::new(warnings_mode);
166180
let rec_guard = SerRecursionState::default();
167181
let duck_typing_ser_mode = DuckTypingSerMode::from_bool(serialize_as_any);
168182
let extra = self.build_extra(

tests/serializers/test_list_tuple.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from pydantic_core import SchemaError, SchemaSerializer, core_schema, validate_core_schema
7+
from pydantic_core import PydanticSerializationError, SchemaError, SchemaSerializer, core_schema, validate_core_schema
88

99

1010
def test_list_any():
@@ -52,6 +52,16 @@ def test_list_str_fallback():
5252
' Expected `str` but got `int` - serialized value may not be as expected\n'
5353
' Expected `str` but got `int` - serialized value may not be as expected'
5454
]
55+
with pytest.raises(PydanticSerializationError) as warning_ex:
56+
v.to_json([1, 2, 3], warnings='error')
57+
assert str(warning_ex.value) == ''.join(
58+
[
59+
'Pydantic serializer warnings:\n'
60+
' Expected `str` but got `int` - serialized value may not be as expected\n'
61+
' Expected `str` but got `int` - serialized value may not be as expected\n'
62+
' Expected `str` but got `int` - serialized value may not be as expected'
63+
]
64+
)
5565

5666

5767
def test_tuple_any():

tests/serializers/test_string.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from pydantic_core import SchemaSerializer, core_schema
6+
from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema
77

88

99
def test_str():
@@ -34,13 +34,44 @@ def test_str_fallback():
3434
assert s.to_python(123, mode='json') == 123
3535
with pytest.warns(UserWarning, match='Expected `str` but got `int` - serialized value may not be as expected'):
3636
assert s.to_json(123) == b'123'
37+
with pytest.warns(UserWarning, match='Expected `str` but got `int` - serialized value may not be as expected'):
38+
assert s.to_python(123, warnings='warn') == 123
39+
with pytest.warns(UserWarning, match='Expected `str` but got `int` - serialized value may not be as expected'):
40+
assert s.to_python(123, mode='json', warnings='warn') == 123
41+
with pytest.warns(UserWarning, match='Expected `str` but got `int` - serialized value may not be as expected'):
42+
assert s.to_json(123, warnings='warn') == b'123'
43+
with pytest.warns(UserWarning, match='Expected `str` but got `int` - serialized value may not be as expected'):
44+
assert s.to_python(123, warnings=True) == 123
45+
with pytest.warns(UserWarning, match='Expected `str` but got `int` - serialized value may not be as expected'):
46+
assert s.to_python(123, mode='json', warnings=True) == 123
47+
with pytest.warns(UserWarning, match='Expected `str` but got `int` - serialized value may not be as expected'):
48+
assert s.to_json(123, warnings=True) == b'123'
3749

3850

3951
def test_str_no_warnings():
4052
s = SchemaSerializer(core_schema.str_schema())
4153
assert s.to_python(123, warnings=False) == 123
54+
assert s.to_python(123, warnings='none') == 123
4255
assert s.to_python(123, mode='json', warnings=False) == 123
56+
assert s.to_python(123, mode='json', warnings='none') == 123
4357
assert s.to_json(123, warnings=False) == b'123'
58+
assert s.to_json(123, warnings='none') == b'123'
59+
60+
61+
def test_str_errors():
62+
s = SchemaSerializer(core_schema.str_schema())
63+
with pytest.raises(
64+
PydanticSerializationError, match='Expected `str` but got `int` - serialized value may not be as expected'
65+
):
66+
assert s.to_python(123, warnings='error') == 123
67+
with pytest.raises(
68+
PydanticSerializationError, match='Expected `str` but got `int` - serialized value may not be as expected'
69+
):
70+
assert s.to_python(123, mode='json', warnings='error') == 123
71+
with pytest.raises(
72+
PydanticSerializationError, match='Expected `str` but got `int` - serialized value may not be as expected'
73+
):
74+
assert s.to_json(123, warnings='error') == b'123'
4475

4576

4677
class StrSubclass(str):

tests/test.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#[cfg(test)]
22
mod tests {
3-
use _pydantic_core::{SchemaSerializer, SchemaValidator};
3+
use _pydantic_core::{SchemaSerializer, SchemaValidator, WarningsArg};
44
use pyo3::prelude::*;
55
use pyo3::types::PyDict;
66

@@ -85,7 +85,20 @@ a = A()
8585
let serialized = SchemaSerializer::py_new(schema, None)
8686
.unwrap()
8787
.to_json(
88-
py, &a, None, None, None, true, false, false, false, false, true, None, false, None,
88+
py,
89+
&a,
90+
None,
91+
None,
92+
None,
93+
true,
94+
false,
95+
false,
96+
false,
97+
false,
98+
WarningsArg::Bool(true),
99+
None,
100+
false,
101+
None,
89102
)
90103
.unwrap();
91104
let serialized: &[u8] = serialized.extract(py).unwrap();
@@ -186,7 +199,7 @@ dump_json_input_2 = {'a': 'something'}
186199
false,
187200
false,
188201
false,
189-
false,
202+
WarningsArg::Bool(false),
190203
None,
191204
false,
192205
None,
@@ -207,7 +220,7 @@ dump_json_input_2 = {'a': 'something'}
207220
false,
208221
false,
209222
false,
210-
false,
223+
WarningsArg::Bool(false),
211224
None,
212225
false,
213226
None,

0 commit comments

Comments
 (0)