Skip to content

Commit eb2e412

Browse files
andreslisztAndres
authored andcommitted
Support exclude_if callable at field level
1 parent 0a5bbfc commit eb2e412

File tree

9 files changed

+142
-38
lines changed

9 files changed

+142
-38
lines changed

python/pydantic_core/core_schema.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2839,6 +2839,7 @@ class TypedDictField(TypedDict, total=False):
28392839
serialization_alias: str
28402840
serialization_exclude: bool # default: False
28412841
metadata: dict[str, Any]
2842+
serialization_exclude_if: Callable[[Any], bool] # default None
28422843

28432844

28442845
def typed_dict_field(
@@ -2849,6 +2850,7 @@ def typed_dict_field(
28492850
serialization_alias: str | None = None,
28502851
serialization_exclude: bool | None = None,
28512852
metadata: dict[str, Any] | None = None,
2853+
serialization_exclude_if: Callable[[Any], bool] | None = None,
28522854
) -> TypedDictField:
28532855
"""
28542856
Returns a schema that matches a typed dict field, e.g.:
@@ -2865,6 +2867,7 @@ def typed_dict_field(
28652867
validation_alias: The alias(es) to use to find the field in the validation data
28662868
serialization_alias: The alias to use as a key when serializing
28672869
serialization_exclude: Whether to exclude the field when serializing
2870+
serialization_exclude_if: A callable that determines whether to exclude the field when serializing based on its value.
28682871
metadata: Any other information you want to include with the schema, not used by pydantic-core
28692872
"""
28702873
return _dict_not_none(
@@ -2874,6 +2877,7 @@ def typed_dict_field(
28742877
validation_alias=validation_alias,
28752878
serialization_alias=serialization_alias,
28762879
serialization_exclude=serialization_exclude,
2880+
serialization_exclude_if=serialization_exclude_if,
28772881
metadata=metadata,
28782882
)
28792883

@@ -2965,6 +2969,7 @@ class ModelField(TypedDict, total=False):
29652969
validation_alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]]
29662970
serialization_alias: str
29672971
serialization_exclude: bool # default: False
2972+
serialization_exclude_if: Callable[[Any], bool] # default: None
29682973
frozen: bool
29692974
metadata: dict[str, Any]
29702975

@@ -2975,6 +2980,7 @@ def model_field(
29752980
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
29762981
serialization_alias: str | None = None,
29772982
serialization_exclude: bool | None = None,
2983+
exclude_if: Callable[[Any], bool] | None = None,
29782984
frozen: bool | None = None,
29792985
metadata: dict[str, Any] | None = None,
29802986
) -> ModelField:
@@ -2992,6 +2998,7 @@ def model_field(
29922998
validation_alias: The alias(es) to use to find the field in the validation data
29932999
serialization_alias: The alias to use as a key when serializing
29943000
serialization_exclude: Whether to exclude the field when serializing
3001+
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
29953002
frozen: Whether the field is frozen
29963003
metadata: Any other information you want to include with the schema, not used by pydantic-core
29973004
"""
@@ -3001,6 +3008,7 @@ def model_field(
30013008
validation_alias=validation_alias,
30023009
serialization_alias=serialization_alias,
30033010
serialization_exclude=serialization_exclude,
3011+
exclude_if=exclude_if,
30043012
frozen=frozen,
30053013
metadata=metadata,
30063014
)
@@ -3193,6 +3201,7 @@ class DataclassField(TypedDict, total=False):
31933201
serialization_alias: str
31943202
serialization_exclude: bool # default: False
31953203
metadata: dict[str, Any]
3204+
serialization_exclude_if: Callable[[Any], bool] # default: None
31963205

31973206

31983207
def dataclass_field(
@@ -3206,6 +3215,7 @@ def dataclass_field(
32063215
serialization_alias: str | None = None,
32073216
serialization_exclude: bool | None = None,
32083217
metadata: dict[str, Any] | None = None,
3218+
serialization_exclude_if: Callable[[Any], bool] | None = None,
32093219
frozen: bool | None = None,
32103220
) -> DataclassField:
32113221
"""
@@ -3231,6 +3241,7 @@ def dataclass_field(
32313241
validation_alias: The alias(es) to use to find the field in the validation data
32323242
serialization_alias: The alias to use as a key when serializing
32333243
serialization_exclude: Whether to exclude the field when serializing
3244+
serialization_exclude_if: A callable that determines whether to exclude the field when serializing based on its value.
32343245
metadata: Any other information you want to include with the schema, not used by pydantic-core
32353246
frozen: Whether the field is frozen
32363247
"""
@@ -3244,6 +3255,7 @@ def dataclass_field(
32443255
validation_alias=validation_alias,
32453256
serialization_alias=serialization_alias,
32463257
serialization_exclude=serialization_exclude,
3258+
serialization_exclude_if=serialization_exclude_if,
32473259
metadata=metadata,
32483260
frozen=frozen,
32493261
)

src/serializers/fields.rs

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub(super) struct SerField {
2929
pub serializer: Option<CombinedSerializer>,
3030
pub required: bool,
3131
pub serialize_by_alias: Option<bool>,
32+
pub exclude_if: Option<Py<PyAny>>,
3233
}
3334

3435
impl_py_gc_traverse!(SerField { serializer });
@@ -41,6 +42,7 @@ impl SerField {
4142
serializer: Option<CombinedSerializer>,
4243
required: bool,
4344
serialize_by_alias: Option<bool>,
45+
exclude_if: Option<Py<PyAny>>,
4446
) -> Self {
4547
let alias_py = alias.as_ref().map(|alias| PyString::new(py, alias.as_str()).into());
4648
Self {
@@ -50,6 +52,7 @@ impl SerField {
5052
serializer,
5153
required,
5254
serialize_by_alias,
55+
exclude_if,
5356
}
5457
}
5558

@@ -72,6 +75,18 @@ impl SerField {
7275
}
7376
}
7477

78+
fn exclude_if(exclude_if_callable: &Option<Py<PyAny>>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
79+
if let Some(exclude_if_callable) = exclude_if_callable {
80+
let py = value.py();
81+
let result = exclude_if_callable.call1(py, (value,))?;
82+
let exclude = result.extract::<bool>(py)?;
83+
if exclude {
84+
return Ok(true);
85+
}
86+
}
87+
Ok(false)
88+
}
89+
7590
fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
7691
if extra.exclude_defaults {
7792
if let Some(default) = serializer.get_default(value.py())? {
@@ -176,16 +191,16 @@ impl GeneralFieldsSerializer {
176191
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
177192
if let Some(field) = op_field {
178193
if let Some(ref serializer) = field.serializer {
179-
if !exclude_default(&value, &field_extra, serializer)? {
180-
let value = serializer.to_python(
181-
&value,
182-
next_include.as_ref(),
183-
next_exclude.as_ref(),
184-
&field_extra,
185-
)?;
186-
let output_key = field.get_key_py(output_dict.py(), &field_extra);
187-
output_dict.set_item(output_key, value)?;
194+
if exclude_default(&value, &field_extra, serializer)? {
195+
continue;
188196
}
197+
if exclude_if(&field.exclude_if, &value)? {
198+
continue;
199+
}
200+
let value =
201+
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?;
202+
let output_key = field.get_key_py(output_dict.py(), &field_extra);
203+
output_dict.set_item(output_key, value)?;
189204
}
190205

191206
if field.required {
@@ -257,17 +272,21 @@ impl GeneralFieldsSerializer {
257272
if let Some((next_include, next_exclude)) = filter {
258273
if let Some(field) = self.fields.get(key_str) {
259274
if let Some(ref serializer) = field.serializer {
260-
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
261-
let s = PydanticSerializer::new(
262-
&value,
263-
serializer,
264-
next_include.as_ref(),
265-
next_exclude.as_ref(),
266-
&field_extra,
267-
);
268-
let output_key = field.get_key_json(key_str, &field_extra);
269-
map.serialize_entry(&output_key, &s)?;
275+
if exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
276+
continue;
277+
}
278+
if exclude_if(&field.exclude_if, &value).map_err(py_err_se_err)? {
279+
continue;
270280
}
281+
let s = PydanticSerializer::new(
282+
&value,
283+
serializer,
284+
next_include.as_ref(),
285+
next_exclude.as_ref(),
286+
&field_extra,
287+
);
288+
let output_key = field.get_key_json(key_str, &field_extra);
289+
map.serialize_entry(&output_key, &s)?;
271290
}
272291
} else if self.mode == FieldsMode::TypedDictAllow {
273292
let output_key = infer_json_key(&key, &field_extra).map_err(py_err_se_err)?;

src/serializers/type_serializers/dataclass.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,27 @@ impl BuildSerializer for DataclassArgsBuilder {
4444
let name: String = field_info.get_as_req(intern!(py, "name"))?;
4545

4646
let key_py: Py<PyString> = PyString::new(py, &name).into();
47-
4847
if !field_info.get_as(intern!(py, "init_only"))?.unwrap_or(false) {
4948
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
50-
fields.insert(name, SerField::new(py, key_py, None, None, true, serialize_by_alias));
49+
fields.insert(name, SerField::new(py, key_py, None, None, true, None));
5150
} else {
5251
let schema = field_info.get_as_req(intern!(py, "schema"))?;
5352
let serializer = CombinedSerializer::build(&schema, config, definitions)
5453
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", index, e))?;
5554

5655
let alias = field_info.get_as(intern!(py, "serialization_alias"))?;
56+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
5757
fields.insert(
5858
name,
59-
SerField::new(py, key_py, alias, Some(serializer), true, serialize_by_alias),
59+
SerField::new(py, key_py, alias, Some(serializer), true, exclude_if),
6060
);
6161
}
62-
}
63-
}
62+
};
6463

65-
let computed_fields = ComputedFields::new(schema, config, definitions)?;
64+
let computed_fields = ComputedFields::new(schema, config, definitions)?;
6665

67-
Ok(GeneralFieldsSerializer::new(fields, fields_mode, None, computed_fields).into())
66+
Ok(GeneralFieldsSerializer::new(fields, fields_mode, None, computed_fields).into())
67+
}
6868
}
6969
}
7070

src/serializers/type_serializers/model.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,28 @@ impl BuildSerializer for ModelFieldsBuilder {
5757
let key_py: Py<PyString> = key_py.into();
5858

5959
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
60-
fields.insert(key, SerField::new(py, key_py, None, None, true, serialize_by_alias));
60+
fields.insert(
61+
key,
62+
SerField::new(py, key_py, None, None, true, serialize_by_alias, None),
63+
);
6164
} else {
6265
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
63-
66+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
6467
let schema = field_info.get_as_req(intern!(py, "schema"))?;
6568
let serializer = CombinedSerializer::build(&schema, config, definitions)
6669
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;
6770

6871
fields.insert(
6972
key,
70-
SerField::new(py, key_py, alias, Some(serializer), true, serialize_by_alias),
73+
SerField::new(
74+
py,
75+
key_py,
76+
alias,
77+
Some(serializer),
78+
true,
79+
serialize_by_alias,
80+
exclude_if,
81+
),
7182
);
7283
}
7384
}

src/serializers/type_serializers/typed_dict.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,27 @@ impl BuildSerializer for TypedDictBuilder {
5454
let required = field_info.get_as(intern!(py, "required"))?.unwrap_or(total);
5555

5656
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
57-
fields.insert(key, SerField::new(py, key_py, None, None, required, serialize_by_alias));
57+
fields.insert(
58+
key,
59+
SerField::new(py, key_py, None, None, required, serialize_by_alias, None),
60+
);
5861
} else {
5962
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
60-
63+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
6164
let schema = field_info.get_as_req(intern!(py, "schema"))?;
6265
let serializer = CombinedSerializer::build(&schema, config, definitions)
6366
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;
6467
fields.insert(
6568
key,
66-
SerField::new(py, key_py, alias, Some(serializer), required, serialize_by_alias),
69+
SerField::new(
70+
py,
71+
key_py,
72+
alias,
73+
Some(serializer),
74+
required,
75+
serialize_by_alias,
76+
exclude_if,
77+
),
6778
);
6879
}
6980
}

tests/serializers/test_dataclasses.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_serialization_exclude():
5454
core_schema.dataclass_args_schema(
5555
'Foo',
5656
[
57-
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
57+
core_schema.dataclass_field(name='a', schema=core_schema.str_schema(), exclude_if=lambda x: x == 'bye'),
5858
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True),
5959
],
6060
),
@@ -63,12 +63,18 @@ def test_serialization_exclude():
6363
s = SchemaSerializer(schema)
6464
assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'}
6565
assert s.to_python(Foo(a='hello', b=b'more'), mode='json') == {'a': 'hello'}
66+
# a = 'bye' excludes it
67+
assert s.to_python(Foo(a='bye', b=b'more'), mode='json') == {}
6668
j = s.to_json(Foo(a='hello', b=b'more'))
67-
6869
if on_pypy:
6970
assert json.loads(j) == {'a': 'hello'}
7071
else:
7172
assert j == b'{"a":"hello"}'
73+
j = s.to_json(Foo(a='bye', b=b'more'))
74+
if on_pypy:
75+
assert json.loads(j) == {}
76+
else:
77+
assert j == b'{}'
7278

7379

7480
def test_serialization_alias():

tests/serializers/test_functions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,9 @@ def __init__(self, **kwargs):
517517
MyModel,
518518
core_schema.typed_dict_schema(
519519
{
520-
'a': core_schema.typed_dict_field(core_schema.any_schema()),
520+
'a': core_schema.typed_dict_field(
521+
core_schema.any_schema(), exclude_if=lambda x: isinstance(x, int) and x >= 2
522+
),
521523
'b': core_schema.typed_dict_field(core_schema.any_schema()),
522524
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
523525
}
@@ -541,6 +543,14 @@ def __init__(self, **kwargs):
541543
assert s.to_json(m, exclude={'b'}) == b'{"a":1}'
542544
assert calls == 6
543545

546+
m = MyModel(a=2, b=b'foobar', c='excluded')
547+
assert s.to_python(m) == {'b': b'foobar'}
548+
assert calls == 7
549+
assert s.to_python(m, mode='json') == {'b': 'foobar'}
550+
assert calls == 8
551+
assert s.to_json(m) == b'{"b":"foobar"}'
552+
assert calls == 9
553+
544554

545555
def test_function_plain_model():
546556
calls = 0
@@ -559,7 +569,7 @@ def __init__(self, **kwargs):
559569
MyModel,
560570
core_schema.typed_dict_schema(
561571
{
562-
'a': core_schema.typed_dict_field(core_schema.any_schema()),
572+
'a': core_schema.typed_dict_field(core_schema.any_schema(), exclude_if=lambda x: x == 100),
563573
'b': core_schema.typed_dict_field(core_schema.any_schema()),
564574
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
565575
}

tests/serializers/test_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,32 @@ def test_include_exclude_args(params):
203203
assert json.loads(s.to_json(value, include=include, exclude=exclude)) == expected
204204

205205

206+
def test_exclude_if():
207+
s = SchemaSerializer(
208+
core_schema.model_schema(
209+
BasicModel,
210+
core_schema.model_fields_schema(
211+
{
212+
'a': core_schema.model_field(core_schema.int_schema(), exclude_if=lambda x: x > 1),
213+
'b': core_schema.model_field(core_schema.str_schema(), exclude_if=lambda x: 'foo' in x),
214+
'c': core_schema.model_field(
215+
core_schema.str_schema(), serialization_exclude=True, exclude_if=lambda x: 'foo' in x
216+
),
217+
}
218+
),
219+
)
220+
)
221+
assert s.to_python(BasicModel(a=0, b='bar', c='bar')) == {'a': 0, 'b': 'bar'}
222+
assert s.to_python(BasicModel(a=2, b='bar', c='bar')) == {'b': 'bar'}
223+
assert s.to_python(BasicModel(a=0, b='foo', c='bar')) == {'a': 0}
224+
assert s.to_python(BasicModel(a=2, b='foo', c='bar')) == {}
225+
226+
assert s.to_json(BasicModel(a=0, b='bar', c='bar')) == b'{"a":0,"b":"bar"}'
227+
assert s.to_json(BasicModel(a=2, b='bar', c='bar')) == b'{"b":"bar"}'
228+
assert s.to_json(BasicModel(a=0, b='foo', c='bar')) == b'{"a":0}'
229+
assert s.to_json(BasicModel(a=2, b='foo', c='bar')) == b'{}'
230+
231+
206232
def test_alias():
207233
s = SchemaSerializer(
208234
core_schema.model_schema(

0 commit comments

Comments
 (0)