Skip to content

Commit db7b731

Browse files
authored
Merge pull request #7 from Mng-dev-ai/add-call-to-schema-trait
Add call to schema trait
2 parents 7dba5fd + f516bd5 commit db7b731

File tree

3 files changed

+75
-34
lines changed

3 files changed

+75
-34
lines changed

schemars/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __new__(cls, **kwargs):
4040
strict=kwargs.get("strict", False),
4141
default=kwargs.get("default", None),
4242
source=kwargs.get("source", None),
43+
call=kwargs.get("call", False),
4344
serialize_func=kwargs.get("serialize_func", None),
4445
context=kwargs.get("context", {}),
4546
)
@@ -60,6 +61,7 @@ def __init__(self, **kwargs):
6061
self.source = kwargs.get("source", None)
6162
self.write_only = kwargs.get("write_only", False)
6263
self.strict = kwargs.get("strict", False)
64+
self.call = kwargs.get("call", False)
6365
self.default = kwargs.get("default", None)
6466
self.serialize_func = kwargs.get("serialize_func", None)
6567
self.context = kwargs.get("context", {})

src/schema.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,15 @@ impl_py_methods!(Schema, required, { fields: HashMap<String, Field>, context: Ha
7676
many: Option<bool>,
7777
parent: Option<PyObject>,
7878
) -> PyResult<PyObject> {
79+
if instance.is_none() {
80+
return Ok(py.None());
81+
}
82+
7983
if let Some(callback) = &self.base.serialize_func {
8084
let result = callback.call1(py, (instance,))?;
8185
return Ok(result);
8286
}
87+
8388
if let Some(true) = many {
8489
let pylist = instance.downcast::<PyList>()?;
8590
let mut results: Vec<PyObject> = Vec::with_capacity(pylist.len());
@@ -153,4 +158,7 @@ impl FieldTrait for Schema {
153158
fn is_method_field(&self) -> bool {
154159
self.base.is_method_field
155160
}
161+
fn call(&self) -> bool {
162+
self.base.call
163+
}
156164
}

tests/test_schema.py

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,154 +1,185 @@
11
import pytest
22
import schemars
33

4+
45
def test_serialize_many():
56
class Product:
67
def __init__(self, name):
78
self.name = name
9+
810
class ProductSchema(schemars.Schema):
911
name = schemars.Str()
10-
12+
1113
schema = ProductSchema()
1214
products = [
1315
Product("Product 1"),
1416
Product("Product 2"),
1517
]
16-
result = schema.serialize(products,many=True)
18+
result = schema.serialize(products, many=True)
1719
assert result == [{"name": "Product 1"}, {"name": "Product 2"}]
1820

21+
1922
def test_serialize_with_default():
2023
class Product:
2124
def __init__(self, name=None):
2225
self.name = name
26+
2327
class ProductSchema(schemars.Schema):
2428
name = schemars.Str(default="Product 1")
25-
29+
2630
schema = ProductSchema()
2731
product = Product()
2832
result = schema.serialize(product)
2933
assert result == {"name": "Product 1"}
30-
34+
35+
3136
def test_serialize_with_default_none():
3237
class Product:
3338
def __init__(self, name=None):
3439
self.name = name
40+
3541
class ProductSchema(schemars.Schema):
3642
name = schemars.Str(default=None)
37-
43+
3844
schema = ProductSchema()
3945
product = Product()
4046
result = schema.serialize(product)
4147
assert result == {"name": None}
42-
48+
49+
4350
def test_serialize_with_write_only():
4451
class Product:
4552
def __init__(self, name):
4653
self.name = name
54+
4755
class ProductSchema(schemars.Schema):
4856
name = schemars.Str(write_only=True)
49-
57+
5058
schema = ProductSchema()
5159
product = Product("Product 1")
5260
result = schema.serialize(product)
5361
assert result == {}
54-
62+
63+
5564
def test_serialize_with_source():
5665
class User:
5766
def __init__(self, name, age):
5867
self.name = name
5968
self.age = age
60-
69+
6170
class Product:
6271
def __init__(self, user):
6372
self.user = user
64-
73+
6574
class ProductSchema(schemars.Schema):
6675
name = schemars.Str(source="user.name")
6776
age = schemars.Int(source="user.age")
68-
6977

7078
schema = ProductSchema()
7179
user = User("John", 30)
7280
product = Product(user)
7381
result = schema.serialize(product)
7482
assert result == {"name": "John", "age": 30}
75-
83+
84+
7685
def test_serialize_with_call():
7786
def custom_func():
7887
return "test"
79-
88+
89+
def get_tags():
90+
return [
91+
Tag("tag1"),
92+
Tag("tag2"),
93+
]
94+
8095
class Product:
8196
@property
8297
def test(self):
8398
return custom_func
84-
99+
100+
@property
101+
def tags(self):
102+
return get_tags
103+
104+
class Tag:
105+
def __init__(self, name):
106+
self.name = name
107+
108+
class TagSchema(schemars.Schema):
109+
name = schemars.Str()
110+
85111
class ProductSchema(schemars.Schema):
86112
test = schemars.Str(call=True)
87-
113+
tags = TagSchema(many=True, call=True)
114+
88115
schema = ProductSchema()
89116
product = Product()
90117
result = schema.serialize(product)
91-
assert result == {"test": "test"}
92-
118+
assert result == {"test": "test", "tags": [{"name": "tag1"}, {"name": "tag2"}]}
119+
120+
93121
def test_serialize_with_serialize_func():
94122
class ProductSchema(schemars.Schema):
95123
name = schemars.Str(serialize_func=lambda name: name.upper())
96-
124+
97125
class Product:
98126
def __init__(self, name):
99127
self.name = name
100-
128+
101129
schema = ProductSchema()
102130
product = Product("Product 1")
103131
result = schema.serialize(product)
104132
assert result == {"name": "PRODUCT 1"}
105-
133+
134+
106135
def test_serialize_with_nested():
107136
class User:
108137
def __init__(self, name, age):
109138
self.name = name
110139
self.age = age
111-
140+
112141
class Product:
113142
def __init__(self, user):
114143
self.user = user
115-
144+
116145
class UserSchema(schemars.Schema):
117146
name = schemars.Str()
118147
age = schemars.Int()
119-
148+
120149
class ProductSchema(schemars.Schema):
121150
user = UserSchema()
122-
151+
123152
schema = ProductSchema()
124153
user = User("John", 30)
125154
product = Product(user)
126155
result = schema.serialize(product)
127156
assert result == {"user": {"name": "John", "age": 30}}
128-
157+
158+
129159
def test_serialize_with_nested_many():
130160
class User:
131161
def __init__(self, name, age):
132162
self.name = name
133163
self.age = age
134-
164+
135165
class Product:
136166
def __init__(self, users):
137167
self.users = users
138-
168+
139169
class UserSchema(schemars.Schema):
140170
name = schemars.Str()
141171
age = schemars.Int()
142-
172+
143173
class ProductSchema(schemars.Schema):
144174
users = UserSchema(many=True)
145-
175+
146176
schema = ProductSchema()
147177
user = User("John", 30)
148178
product = Product([user])
149179
result = schema.serialize(product)
150180
assert result == {"users": [{"name": "John", "age": 30}]}
151181

182+
152183
def test_serialize_with_context():
153184
class Product:
154185
def __init__(self, name):
@@ -158,16 +189,16 @@ class ProductSchema(schemars.Schema):
158189
name = schemars.Str()
159190

160191
def serialize(self, instance, many=None):
161-
context_suffix = self.context.get('suffix') if self.context else ''
192+
context_suffix = self.context.get("suffix") if self.context else ""
162193
result = super().serialize(instance, many)
163194
if not many:
164-
result['name'] += context_suffix
195+
result["name"] += context_suffix
165196
else:
166197
for item in result:
167-
item['name'] += context_suffix
198+
item["name"] += context_suffix
168199
return result
169-
170-
schema = ProductSchema(context={'suffix': ' - Context'})
200+
201+
schema = ProductSchema(context={"suffix": " - Context"})
171202
product = Product("Product 1")
172203
result = schema.serialize(product)
173204
assert result == {"name": "Product 1 - Context"}

0 commit comments

Comments
 (0)