Skip to content

Commit a03210d

Browse files
andreslisztAndres
authored andcommitted
Support exclude_if callable at field level
1 parent 83ff1cf commit a03210d

File tree

9 files changed

+129
-37
lines changed

9 files changed

+129
-37
lines changed

python/pydantic_core/core_schema.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2817,7 +2817,8 @@ class TypedDictField(TypedDict, total=False):
28172817
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
28182818
serialization_alias: str
28192819
serialization_exclude: bool # default: False
2820-
metadata: Dict[str, Any]
2820+
serialization_exclude_if: Callable[[Any], bool] # default None
2821+
metadata: Any
28212822

28222823

28232824
def typed_dict_field(
@@ -2827,7 +2828,8 @@ def typed_dict_field(
28272828
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
28282829
serialization_alias: str | None = None,
28292830
serialization_exclude: bool | None = None,
2830-
metadata: Dict[str, Any] | None = None,
2831+
serialization_exclude_if: Callable[[Any], bool] | None = None,
2832+
metadata: Any = None,
28312833
) -> TypedDictField:
28322834
"""
28332835
Returns a schema that matches a typed dict field, e.g.:
@@ -2844,6 +2846,7 @@ def typed_dict_field(
28442846
validation_alias: The alias(es) to use to find the field in the validation data
28452847
serialization_alias: The alias to use as a key when serializing
28462848
serialization_exclude: Whether to exclude the field when serializing
2849+
serialization_exclude_if: A callable that determines whether to exclude the field when serializing based on its value.
28472850
metadata: Any other information you want to include with the schema, not used by pydantic-core
28482851
"""
28492852
return _dict_not_none(
@@ -2853,6 +2856,7 @@ def typed_dict_field(
28532856
validation_alias=validation_alias,
28542857
serialization_alias=serialization_alias,
28552858
serialization_exclude=serialization_exclude,
2859+
serialization_exclude_if=serialization_exclude_if,
28562860
metadata=metadata,
28572861
)
28582862

@@ -2943,6 +2947,7 @@ class ModelField(TypedDict, total=False):
29432947
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
29442948
serialization_alias: str
29452949
serialization_exclude: bool # default: False
2950+
serialization_exclude_if: Callable[[Any], bool] # default: None
29462951
frozen: bool
29472952
metadata: Dict[str, Any]
29482953

@@ -2953,6 +2958,7 @@ def model_field(
29532958
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
29542959
serialization_alias: str | None = None,
29552960
serialization_exclude: bool | None = None,
2961+
exclude_if: Callable[[Any], bool] | None = None,
29562962
frozen: bool | None = None,
29572963
metadata: Dict[str, Any] | None = None,
29582964
) -> ModelField:
@@ -2970,6 +2976,7 @@ def model_field(
29702976
validation_alias: The alias(es) to use to find the field in the validation data
29712977
serialization_alias: The alias to use as a key when serializing
29722978
serialization_exclude: Whether to exclude the field when serializing
2979+
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
29732980
frozen: Whether the field is frozen
29742981
metadata: Any other information you want to include with the schema, not used by pydantic-core
29752982
"""
@@ -2979,6 +2986,7 @@ def model_field(
29792986
validation_alias=validation_alias,
29802987
serialization_alias=serialization_alias,
29812988
serialization_exclude=serialization_exclude,
2989+
exclude_if=exclude_if,
29822990
frozen=frozen,
29832991
metadata=metadata,
29842992
)
@@ -3171,7 +3179,8 @@ class DataclassField(TypedDict, total=False):
31713179
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
31723180
serialization_alias: str
31733181
serialization_exclude: bool # default: False
3174-
metadata: Dict[str, Any]
3182+
serialization_exclude_if: Callable[[Any], bool] # default: None
3183+
metadata: Any
31753184

31763185

31773186
def dataclass_field(
@@ -3184,7 +3193,8 @@ def dataclass_field(
31843193
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
31853194
serialization_alias: str | None = None,
31863195
serialization_exclude: bool | None = None,
3187-
metadata: Dict[str, Any] | None = None,
3196+
serialization_exclude_if: Callable[[Any], bool] | None = None,
3197+
metadata: Any = None,
31883198
frozen: bool | None = None,
31893199
) -> DataclassField:
31903200
"""
@@ -3210,6 +3220,7 @@ def dataclass_field(
32103220
validation_alias: The alias(es) to use to find the field in the validation data
32113221
serialization_alias: The alias to use as a key when serializing
32123222
serialization_exclude: Whether to exclude the field when serializing
3223+
serialization_exclude_if: A callable that determines whether to exclude the field when serializing based on its value.
32133224
metadata: Any other information you want to include with the schema, not used by pydantic-core
32143225
frozen: Whether the field is frozen
32153226
"""
@@ -3223,6 +3234,7 @@ def dataclass_field(
32233234
validation_alias=validation_alias,
32243235
serialization_alias=serialization_alias,
32253236
serialization_exclude=serialization_exclude,
3237+
serialization_exclude_if=serialization_exclude_if,
32263238
metadata=metadata,
32273239
frozen=frozen,
32283240
)

src/serializers/fields.rs

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub(super) struct SerField {
2929
// None serializer means exclude
3030
pub serializer: Option<CombinedSerializer>,
3131
pub required: bool,
32+
pub exclude_if: Option<Py<PyAny>>,
3233
}
3334

3435
impl_py_gc_traverse!(SerField { serializer });
@@ -40,6 +41,7 @@ impl SerField {
4041
alias: Option<String>,
4142
serializer: Option<CombinedSerializer>,
4243
required: bool,
44+
exclude_if: Option<Py<PyAny>>,
4345
) -> Self {
4446
let alias_py = alias
4547
.as_ref()
@@ -50,6 +52,7 @@ impl SerField {
5052
alias_py,
5153
serializer,
5254
required,
55+
exclude_if,
5356
}
5457
}
5558

@@ -72,6 +75,18 @@ impl SerField {
7275
}
7376
}
7477

78+
fn exclude_if(exclude_if_callable: &Option<Py<PyAny>>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
79+
if let Some(exclude_if_callable) = exclude_if_callable {
80+
let py = value.py();
81+
let result = exclude_if_callable.call1(py, (value,))?;
82+
let exclude = result.extract::<bool>(py)?;
83+
if exclude {
84+
return Ok(true);
85+
}
86+
}
87+
Ok(false)
88+
}
89+
7590
fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
7691
if extra.exclude_defaults {
7792
if let Some(default) = serializer.get_default(value.py())? {
@@ -176,16 +191,16 @@ impl GeneralFieldsSerializer {
176191
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
177192
if let Some(field) = op_field {
178193
if let Some(ref serializer) = field.serializer {
179-
if !exclude_default(&value, &field_extra, serializer)? {
180-
let value = serializer.to_python(
181-
&value,
182-
next_include.as_ref(),
183-
next_exclude.as_ref(),
184-
&field_extra,
185-
)?;
186-
let output_key = field.get_key_py(output_dict.py(), &field_extra);
187-
output_dict.set_item(output_key, value)?;
194+
if exclude_default(&value, &field_extra, serializer)? {
195+
continue;
188196
}
197+
if exclude_if(&field.exclude_if, &value)? {
198+
continue;
199+
}
200+
let value =
201+
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?;
202+
let output_key = field.get_key_py(output_dict.py(), &field_extra);
203+
output_dict.set_item(output_key, value)?;
189204
}
190205

191206
if field.required {
@@ -263,17 +278,21 @@ impl GeneralFieldsSerializer {
263278
if let Some((next_include, next_exclude)) = filter {
264279
if let Some(field) = self.fields.get(key_str) {
265280
if let Some(ref serializer) = field.serializer {
266-
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
267-
let s = PydanticSerializer::new(
268-
&value,
269-
serializer,
270-
next_include.as_ref(),
271-
next_exclude.as_ref(),
272-
&field_extra,
273-
);
274-
let output_key = field.get_key_json(key_str, &field_extra);
275-
map.serialize_entry(&output_key, &s)?;
281+
if exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
282+
continue;
283+
}
284+
if exclude_if(&field.exclude_if, &value).map_err(py_err_se_err)? {
285+
continue;
276286
}
287+
let s = PydanticSerializer::new(
288+
&value,
289+
serializer,
290+
next_include.as_ref(),
291+
next_exclude.as_ref(),
292+
&field_extra,
293+
);
294+
let output_key = field.get_key_json(key_str, &field_extra);
295+
map.serialize_entry(&output_key, &s)?;
277296
}
278297
} else if self.mode == FieldsMode::TypedDictAllow {
279298
let output_key = infer_json_key(&key, &field_extra).map_err(py_err_se_err)?;

src/serializers/type_serializers/dataclass.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,18 @@ impl BuildSerializer for DataclassArgsBuilder {
4444
let key_py: Py<PyString> = PyString::new_bound(py, &name).into();
4545

4646
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
47-
fields.insert(name, SerField::new(py, key_py, None, None, true));
47+
fields.insert(name, SerField::new(py, key_py, None, None, true, None));
4848
} else {
4949
let schema = field_info.get_as_req(intern!(py, "schema"))?;
5050
let serializer = CombinedSerializer::build(&schema, config, definitions)
5151
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", index, e))?;
5252

5353
let alias = field_info.get_as(intern!(py, "serialization_alias"))?;
54-
fields.insert(name, SerField::new(py, key_py, alias, Some(serializer), true));
54+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
55+
fields.insert(
56+
name,
57+
SerField::new(py, key_py, alias, Some(serializer), true, exclude_if),
58+
);
5559
}
5660
}
5761

src/serializers/type_serializers/model.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,18 @@ impl BuildSerializer for ModelFieldsBuilder {
5454
let key_py: Py<PyString> = key_py.into();
5555

5656
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
57-
fields.insert(key, SerField::new(py, key_py, None, None, true));
57+
fields.insert(key, SerField::new(py, key_py, None, None, true, None));
5858
} else {
5959
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
60-
60+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
6161
let schema = field_info.get_as_req(intern!(py, "schema"))?;
6262
let serializer = CombinedSerializer::build(&schema, config, definitions)
6363
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;
6464

65-
fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), true));
65+
fields.insert(
66+
key,
67+
SerField::new(py, key_py, alias, Some(serializer), true, exclude_if),
68+
);
6669
}
6770
}
6871

src/serializers/type_serializers/typed_dict.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,17 @@ impl BuildSerializer for TypedDictBuilder {
5252
let required = field_info.get_as(intern!(py, "required"))?.unwrap_or(total);
5353

5454
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
55-
fields.insert(key, SerField::new(py, key_py, None, None, required));
55+
fields.insert(key, SerField::new(py, key_py, None, None, required, None));
5656
} else {
5757
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
58-
58+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
5959
let schema = field_info.get_as_req(intern!(py, "schema"))?;
6060
let serializer = CombinedSerializer::build(&schema, config, definitions)
6161
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;
62-
fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), required));
62+
fields.insert(
63+
key,
64+
SerField::new(py, key_py, alias, Some(serializer), required, exclude_if),
65+
);
6366
}
6467
}
6568

tests/serializers/test_dataclasses.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_serialization_exclude():
5454
core_schema.dataclass_args_schema(
5555
'Foo',
5656
[
57-
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
57+
core_schema.dataclass_field(name='a', schema=core_schema.str_schema(), exclude_if=lambda x: x == 'bye'),
5858
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True),
5959
],
6060
),
@@ -63,12 +63,18 @@ def test_serialization_exclude():
6363
s = SchemaSerializer(schema)
6464
assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'}
6565
assert s.to_python(Foo(a='hello', b=b'more'), mode='json') == {'a': 'hello'}
66+
# a = 'bye' excludes it
67+
assert s.to_python(Foo(a='bye', b=b'more'), mode='json') == {}
6668
j = s.to_json(Foo(a='hello', b=b'more'))
67-
6869
if on_pypy:
6970
assert json.loads(j) == {'a': 'hello'}
7071
else:
7172
assert j == b'{"a":"hello"}'
73+
j = s.to_json(Foo(a='bye', b=b'more'))
74+
if on_pypy:
75+
assert json.loads(j) == {}
76+
else:
77+
assert j == b'{}'
7278

7379

7480
def test_serialization_alias():

tests/serializers/test_functions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,9 @@ def __init__(self, **kwargs):
511511
MyModel,
512512
core_schema.typed_dict_schema(
513513
{
514-
'a': core_schema.typed_dict_field(core_schema.any_schema()),
514+
'a': core_schema.typed_dict_field(
515+
core_schema.any_schema(), exclude_if=lambda x: isinstance(x, int) and x >= 2
516+
),
515517
'b': core_schema.typed_dict_field(core_schema.any_schema()),
516518
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
517519
}
@@ -535,6 +537,14 @@ def __init__(self, **kwargs):
535537
assert s.to_json(m, exclude={'b'}) == b'{"a":1}'
536538
assert calls == 6
537539

540+
m = MyModel(a=2, b=b'foobar', c='excluded')
541+
assert s.to_python(m) == {'b': b'foobar'}
542+
assert calls == 7
543+
assert s.to_python(m, mode='json') == {'b': 'foobar'}
544+
assert calls == 8
545+
assert s.to_json(m) == b'{"b":"foobar"}'
546+
assert calls == 9
547+
538548

539549
def test_function_plain_model():
540550
calls = 0
@@ -553,7 +563,7 @@ def __init__(self, **kwargs):
553563
MyModel,
554564
core_schema.typed_dict_schema(
555565
{
556-
'a': core_schema.typed_dict_field(core_schema.any_schema()),
566+
'a': core_schema.typed_dict_field(core_schema.any_schema(), exclude_if=lambda x: x == 100),
557567
'b': core_schema.typed_dict_field(core_schema.any_schema()),
558568
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
559569
}

tests/serializers/test_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,32 @@ def test_include_exclude_args(params):
203203
assert json.loads(s.to_json(value, include=include, exclude=exclude)) == expected
204204

205205

206+
def test_exclude_if():
207+
s = SchemaSerializer(
208+
core_schema.model_schema(
209+
BasicModel,
210+
core_schema.model_fields_schema(
211+
{
212+
'a': core_schema.model_field(core_schema.int_schema(), exclude_if=lambda x: x > 1),
213+
'b': core_schema.model_field(core_schema.str_schema(), exclude_if=lambda x: 'foo' in x),
214+
'c': core_schema.model_field(
215+
core_schema.str_schema(), serialization_exclude=True, exclude_if=lambda x: 'foo' in x
216+
),
217+
}
218+
),
219+
)
220+
)
221+
assert s.to_python(BasicModel(a=0, b='bar', c='bar')) == {'a': 0, 'b': 'bar'}
222+
assert s.to_python(BasicModel(a=2, b='bar', c='bar')) == {'b': 'bar'}
223+
assert s.to_python(BasicModel(a=0, b='foo', c='bar')) == {'a': 0}
224+
assert s.to_python(BasicModel(a=2, b='foo', c='bar')) == {}
225+
226+
assert s.to_json(BasicModel(a=0, b='bar', c='bar')) == b'{"a":0,"b":"bar"}'
227+
assert s.to_json(BasicModel(a=2, b='bar', c='bar')) == b'{"b":"bar"}'
228+
assert s.to_json(BasicModel(a=0, b='foo', c='bar')) == b'{"a":0}'
229+
assert s.to_json(BasicModel(a=2, b='foo', c='bar')) == b'{}'
230+
231+
206232
def test_alias():
207233
s = SchemaSerializer(
208234
core_schema.model_schema(

tests/serializers/test_typed_dict.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,12 @@ def test_include_exclude_schema():
9292
{
9393
'0': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=True),
9494
'1': core_schema.typed_dict_field(core_schema.int_schema()),
95-
'2': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=True),
96-
'3': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=False),
95+
'2': core_schema.typed_dict_field(
96+
core_schema.int_schema(), serialization_exclude=True, exclude_if=lambda x: x < 0
97+
),
98+
'3': core_schema.typed_dict_field(
99+
core_schema.int_schema(), serialization_exclude=False, exclude_if=lambda x: x < 0
100+
),
97101
}
98102
)
99103
)
@@ -102,6 +106,11 @@ def test_include_exclude_schema():
102106
assert s.to_python(value, mode='json') == {'1': 1, '3': 3}
103107
assert json.loads(s.to_json(value)) == {'1': 1, '3': 3}
104108

109+
value = {'0': 0, '1': 1, '2': 2, '3': -3}
110+
assert s.to_python(value) == {'1': 1}
111+
assert s.to_python(value, mode='json') == {'1': 1}
112+
assert json.loads(s.to_json(value)) == {'1': 1}
113+
105114

106115
def test_alias():
107116
s = SchemaSerializer(

0 commit comments

Comments
 (0)