Skip to content

Commit b46fd16

Browse files
committed
WIP: add fraction
1 parent 70bd6f9 commit b46fd16

File tree

6 files changed

+213
-5
lines changed

6 files changed

+213
-5
lines changed

python/pydantic_core/core_schema.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,62 @@ def decimal_schema(
809809
serialization=serialization,
810810
)
811811

812+
def fraction_schema(
813+
*,
814+
allow_inf_nan: bool | None = None,
815+
multiple_of: Fraction | None = None,
816+
le: Fraction | None = None,
817+
ge: Fraction | None = None,
818+
lt: Fraction | None = None,
819+
gt: Fraction | None = None,
820+
max_digits: int | None = None,
821+
decimal_places: int | None = None,
822+
strict: bool | None = None,
823+
ref: str | None = None,
824+
metadata: dict[str, Any] | None = None,
825+
serialization: SerSchema | None = None,
826+
) -> FractionSchema:
827+
"""
828+
Returns a schema that matches a decimal value, e.g.:
829+
830+
```py
831+
from fractions import Fraction
832+
from pydantic_core import SchemaValidator, core_schema
833+
834+
schema = core_schema.fraction_schema(le=0.8, ge=0.2)
835+
v = SchemaValidator(schema)
836+
assert v.validate_python(1, 2) == Fraction(1, 2)
837+
```
838+
839+
Args:
840+
allow_inf_nan: Whether to allow inf and nan values
841+
multiple_of: The value must be a multiple of this number
842+
le: The value must be less than or equal to this number
843+
ge: The value must be greater than or equal to this number
844+
lt: The value must be strictly less than this number
845+
gt: The value must be strictly greater than this number
846+
max_digits: The maximum number of decimal digits allowed
847+
decimal_places: The maximum number of decimal places allowed
848+
strict: Whether the value should be a float or a value that can be converted to a float
849+
ref: optional unique identifier of the schema, used to reference the schema in other places
850+
metadata: Any other information you want to include with the schema, not used by pydantic-core
851+
serialization: Custom serialization schema
852+
"""
853+
return _dict_not_none(
854+
type='fraction',
855+
gt=gt,
856+
ge=ge,
857+
lt=lt,
858+
le=le,
859+
max_digits=max_digits,
860+
decimal_places=decimal_places,
861+
multiple_of=multiple_of,
862+
allow_inf_nan=allow_inf_nan,
863+
strict=strict,
864+
ref=ref,
865+
metadata=metadata,
866+
serialization=serialization,
867+
)
812868

813869
class ComplexSchema(TypedDict, total=False):
814870
type: Required[Literal['complex']]

src/serializers/infer.rs

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,16 @@ pub(crate) fn infer_to_python_known(
136136
}
137137
v.into_py_any(py)?
138138
}
139-
ObType::Decimal => value.to_string().into_py_any(py)?,
139+
ObType::Decimal => {
140+
// todo: delete before PR ready
141+
println!("[RUST] infer_to_python_known - SerMode::Json - serializing ObType::Decimal");
142+
value.to_string().into_py_any(py)?
143+
},
144+
ObType::Fraction => {
145+
// todo: delete before PR ready
146+
println!("[RUST] infer_to_python_known - SerMode::Json - serializing ObType::Fraction");
147+
value.to_string().into_py_any(py)?
148+
},
140149
ObType::StrSubclass => PyString::new(py, value.downcast::<PyString>()?.to_str()?).into(),
141150
ObType::Bytes => extra
142151
.config
@@ -430,7 +439,16 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
430439
let v = value.extract::<f64>().map_err(py_err_se_err)?;
431440
type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode)
432441
}
433-
ObType::Decimal => value.to_string().serialize(serializer),
442+
ObType::Decimal => {
443+
// todo: delete before PR ready
444+
println!("[RUST] infer_serialize_known - serializing ObType::Decimal");
445+
value.to_string().serialize(serializer)
446+
},
447+
ObType::Fraction => {
448+
// todo: delete before PR ready
449+
println!("[RUST] infer_serialize_known - serializing ObType::Fraction");
450+
value.to_string().serialize(serializer)
451+
},
434452
ObType::Str | ObType::StrSubclass => {
435453
let py_str = value.downcast::<PyString>().map_err(py_err_se_err)?;
436454
super::type_serializers::string::serialize_py_str(py_str, serializer)
@@ -612,7 +630,16 @@ pub(crate) fn infer_json_key_known<'a>(
612630
super::type_serializers::simple::to_str_json_key(key)
613631
}
614632
}
615-
ObType::Decimal => Ok(Cow::Owned(key.to_string())),
633+
ObType::Decimal => {
634+
// todo: delete before PR ready
635+
println!("[RUST] infer_json_key_known - converting ObType::Decimal to json key");
636+
Ok(Cow::Owned(key.to_string()))
637+
},
638+
ObType::Fraction => {
639+
// todo: delete before PR ready
640+
println!("[RUST] infer_json_key_known - converting ObType::Fraction to json key");
641+
Ok(Cow::Owned(key.to_string()))
642+
},
616643
ObType::Bool => super::type_serializers::simple::bool_json_key(key),
617644
ObType::Str | ObType::StrSubclass => key.downcast::<PyString>()?.to_cow(),
618645
ObType::Bytes => extra

src/serializers/ob_type.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub struct ObTypeLookup {
2323
dict: usize,
2424
// other numeric types
2525
decimal_object: Py<PyAny>,
26+
fraction_object: Py<PyAny>,
2627
// other string types
2728
bytes: usize,
2829
bytearray: usize,
@@ -62,14 +63,25 @@ pub enum IsType {
6263

6364
impl ObTypeLookup {
6465
fn new(py: Python) -> Self {
66+
// todo: delete before PR ready
67+
println!("[RUST] ObTypeLookup::new");
6568
Self {
6669
none: PyNone::type_object_raw(py) as usize,
6770
int: PyInt::type_object_raw(py) as usize,
6871
bool: PyBool::type_object_raw(py) as usize,
6972
float: PyFloat::type_object_raw(py) as usize,
7073
list: PyList::type_object_raw(py) as usize,
7174
dict: PyDict::type_object_raw(py) as usize,
72-
decimal_object: py.import("decimal").unwrap().getattr("Decimal").unwrap().unbind(),
75+
decimal_object: {
76+
// todo: delete before PR ready
77+
println!("[RUST] ObTypeLookup::new - loading decimal_object");
78+
py.import("decimal").unwrap().getattr("Decimal").unwrap().unbind()
79+
},
80+
fraction_object: {
81+
// todo: delete before PR ready
82+
println!("[RUST] ObTypeLookup::new - loading fraction_object");
83+
py.import("fractions").unwrap().getattr("Fraction").unwrap().unbind()
84+
},
7385
string: PyString::type_object_raw(py) as usize,
7486
bytes: PyBytes::type_object_raw(py) as usize,
7587
bytearray: PyByteArray::type_object_raw(py) as usize,
@@ -96,6 +108,7 @@ impl ObTypeLookup {
96108
}
97109

98110
pub fn is_type(&self, value: &Bound<'_, PyAny>, expected_ob_type: ObType) -> IsType {
111+
println!("[RUST] is_type - expected_ob_type: {expected_ob_type}");
99112
match self.ob_type_is_expected(Some(value), &value.get_type(), &expected_ob_type) {
100113
IsType::False => {
101114
if expected_ob_type == self.fallback_isinstance(value) {
@@ -116,6 +129,7 @@ impl ObTypeLookup {
116129
) -> IsType {
117130
let type_ptr = py_type.as_ptr();
118131
let ob_type = type_ptr as usize;
132+
println!("[RUST] ob_type_is_expected - ob_type: {ob_type}, expected_ob_type: {expected_ob_type}");
119133
let ans = match expected_ob_type {
120134
ObType::None => self.none == ob_type,
121135
ObType::Int => self.int == ob_type,
@@ -137,7 +151,16 @@ impl ObTypeLookup {
137151
ObType::Str => self.string == ob_type,
138152
ObType::List => self.list == ob_type,
139153
ObType::Dict => self.dict == ob_type,
140-
ObType::Decimal => self.decimal_object.as_ptr() as usize == ob_type,
154+
ObType::Decimal => {
155+
// todo: delete before PR ready
156+
println!("[RUST] ob_type_is_expected - checking ObType::Decimal");
157+
self.decimal_object.as_ptr() as usize == ob_type
158+
},
159+
ObType::Fraction => {
160+
// todo: delete before PR ready
161+
println!("[RUST] ob_type_is_expected - checking ObType::Fraction");
162+
self.fraction_object.as_ptr() as usize == ob_type
163+
},
141164
ObType::StrSubclass => self.string == ob_type && op_value.is_none(),
142165
ObType::Tuple => self.tuple == ob_type,
143166
ObType::Set => self.set == ob_type,
@@ -214,7 +237,13 @@ impl ObTypeLookup {
214237
} else if ob_type == self.dict {
215238
ObType::Dict
216239
} else if ob_type == self.decimal_object.as_ptr() as usize {
240+
// todo: delete before PR ready
241+
println!("[RUST] lookup_by_ob_type - found ObType::Decimal");
217242
ObType::Decimal
243+
} else if ob_type == self.fraction_object.as_ptr() as usize {
244+
// todo: delete before PR ready
245+
println!("[RUST] lookup_by_ob_type - found ObType::Fraction");
246+
ObType::Fraction
218247
} else if ob_type == self.bytes {
219248
ObType::Bytes
220249
} else if ob_type == self.tuple {
@@ -322,7 +351,13 @@ impl ObTypeLookup {
322351
} else if value.is_instance_of::<PyMultiHostUrl>() {
323352
ObType::MultiHostUrl
324353
} else if value.is_instance(self.decimal_object.bind(py)).unwrap_or(false) {
354+
// todo: delete before PR ready
355+
println!("[RUST] fallback_isinstance - found ObType::Decimal");
325356
ObType::Decimal
357+
} else if value.is_instance(self.fraction_object.bind(py)).unwrap_or(false) {
358+
// todo: delete before PR ready
359+
println!("[RUST] fallback_isinstance - found ObType::Fraction");
360+
ObType::Fraction
326361
} else if value.is_instance(self.uuid_object.bind(py)).unwrap_or(false) {
327362
ObType::Uuid
328363
} else if value.is_instance(self.enum_object.bind(py)).unwrap_or(false) {
@@ -380,6 +415,7 @@ pub enum ObType {
380415
Float,
381416
FloatSubclass,
382417
Decimal,
418+
Fraction,
383419
// string types
384420
Str,
385421
StrSubclass,

src/serializers/shared.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ combined_serializer! {
118118
Bool: super::type_serializers::simple::BoolSerializer;
119119
Float: super::type_serializers::float::FloatSerializer;
120120
Decimal: super::type_serializers::decimal::DecimalSerializer;
121+
Fraction: super::type_serializers::fraction::FractionSerializer;
121122
Str: super::type_serializers::string::StrSerializer;
122123
Bytes: super::type_serializers::bytes::BytesSerializer;
123124
Datetime: super::type_serializers::datetime_etc::DatetimeSerializer;
@@ -321,6 +322,7 @@ impl PyGcTraverse for CombinedSerializer {
321322
CombinedSerializer::Bool(inner) => inner.py_gc_traverse(visit),
322323
CombinedSerializer::Float(inner) => inner.py_gc_traverse(visit),
323324
CombinedSerializer::Decimal(inner) => inner.py_gc_traverse(visit),
325+
CombinedSerializer::Fraction(inner) => inner.py_gc_traverse(visit),
324326
CombinedSerializer::Str(inner) => inner.py_gc_traverse(visit),
325327
CombinedSerializer::Bytes(inner) => inner.py_gc_traverse(visit),
326328
CombinedSerializer::Datetime(inner) => inner.py_gc_traverse(visit),
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
use std::borrow::Cow;
2+
use std::sync::Arc;
3+
4+
use pyo3::prelude::*;
5+
use pyo3::types::PyDict;
6+
7+
use crate::build_tools::LazyLock;
8+
use crate::definitions::DefinitionsBuilder;
9+
use crate::serializers::infer::{infer_json_key_known, infer_serialize_known, infer_to_python_known};
10+
use crate::serializers::ob_type::{IsType, ObType};
11+
12+
use super::{
13+
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, TypeSerializer,
14+
};
15+
16+
#[derive(Debug)]
17+
pub struct FractionSerializer {}
18+
19+
static FRACTION_SERIALIZER: LazyLock<Arc<CombinedSerializer>> = LazyLock::new(|| Arc::new(FractionSerializer {}.into()));
20+
21+
impl BuildSerializer for FractionSerializer {
22+
const EXPECTED_TYPE: &'static str = "decimal";
23+
24+
fn build(
25+
_schema: &Bound<'_, PyDict>,
26+
_config: Option<&Bound<'_, PyDict>>,
27+
_definitions: &mut DefinitionsBuilder<Arc<CombinedSerializer>>,
28+
) -> PyResult<Arc<CombinedSerializer>> {
29+
Ok(FRACTION_SERIALIZER.clone())
30+
}
31+
}
32+
33+
impl_py_gc_traverse!(FractionSerializer {});
34+
35+
impl TypeSerializer for FractionSerializer {
36+
fn to_python(
37+
&self,
38+
value: &Bound<'_, PyAny>,
39+
include: Option<&Bound<'_, PyAny>>,
40+
exclude: Option<&Bound<'_, PyAny>>,
41+
extra: &Extra,
42+
) -> PyResult<Py<PyAny>> {
43+
let _py = value.py();
44+
println!("[RUST] FractionSerializer to_python called");
45+
match extra.ob_type_lookup.is_type(value, ObType::Fraction) {
46+
IsType::Exact | IsType::Subclass => infer_to_python_known(ObType::Fraction, value, include, exclude, extra),
47+
IsType::False => {
48+
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
49+
infer_to_python(value, include, exclude, extra)
50+
}
51+
}
52+
}
53+
54+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
55+
match extra.ob_type_lookup.is_type(key, ObType::Fraction) {
56+
IsType::Exact | IsType::Subclass => infer_json_key_known(ObType::Fraction, key, extra),
57+
IsType::False => {
58+
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
59+
infer_json_key(key, extra)
60+
}
61+
}
62+
}
63+
64+
fn serde_serialize<S: serde::ser::Serializer>(
65+
&self,
66+
value: &Bound<'_, PyAny>,
67+
serializer: S,
68+
include: Option<&Bound<'_, PyAny>>,
69+
exclude: Option<&Bound<'_, PyAny>>,
70+
extra: &Extra,
71+
) -> Result<S::Ok, S::Error> {
72+
match extra.ob_type_lookup.is_type(value, ObType::Fraction) {
73+
IsType::Exact | IsType::Subclass => {
74+
infer_serialize_known(ObType::Fraction, value, serializer, include, exclude, extra)
75+
}
76+
IsType::False => {
77+
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
78+
infer_serialize(value, serializer, include, exclude, extra)
79+
}
80+
}
81+
}
82+
83+
fn get_name(&self) -> &str {
84+
Self::EXPECTED_TYPE
85+
}
86+
}

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod complex;
44
pub mod dataclass;
55
pub mod datetime_etc;
66
pub mod decimal;
7+
pub mod fraction;
78
pub mod definitions;
89
pub mod dict;
910
pub mod enum_;

0 commit comments

Comments
 (0)