Skip to content

Commit 99460bf

Browse files
committed
WIP: add fraction decimal
1 parent b46fd16 commit 99460bf

File tree

9 files changed

+329
-26
lines changed

9 files changed

+329
-26
lines changed

python/pydantic_core/core_schema.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from collections.abc import Hashable, Mapping
1111
from datetime import date, datetime, time, timedelta
1212
from decimal import Decimal
13+
from fractions import Fraction
1314
from re import Pattern
1415
from typing import TYPE_CHECKING, Any, Callable, Literal, Union
1516

@@ -809,23 +810,30 @@ def decimal_schema(
809810
serialization=serialization,
810811
)
811812

813+
class FractionSchema(TypedDict, total=False):
814+
type: Required[Literal['decimal']]
815+
le: Decimal
816+
ge: Decimal
817+
lt: Decimal
818+
gt: Decimal
819+
strict: bool
820+
ref: str
821+
metadata: dict[str, Any]
822+
serialization: SerSchema
823+
812824
def fraction_schema(
813825
*,
814-
allow_inf_nan: bool | None = None,
815-
multiple_of: Fraction | None = None,
816826
le: Fraction | None = None,
817827
ge: Fraction | None = None,
818828
lt: Fraction | None = None,
819829
gt: Fraction | None = None,
820-
max_digits: int | None = None,
821-
decimal_places: int | None = None,
822830
strict: bool | None = None,
823831
ref: str | None = None,
824832
metadata: dict[str, Any] | None = None,
825833
serialization: SerSchema | None = None,
826834
) -> FractionSchema:
827835
"""
828-
Returns a schema that matches a decimal value, e.g.:
836+
Returns a schema that matches a fraction value, e.g.:
829837
830838
```py
831839
from fractions import Fraction
@@ -837,14 +845,10 @@ def fraction_schema(
837845
```
838846
839847
Args:
840-
allow_inf_nan: Whether to allow inf and nan values
841-
multiple_of: The value must be a multiple of this number
842848
le: The value must be less than or equal to this number
843849
ge: The value must be greater than or equal to this number
844850
lt: The value must be strictly less than this number
845851
gt: The value must be strictly greater than this number
846-
max_digits: The maximum number of decimal digits allowed
847-
decimal_places: The maximum number of decimal places allowed
848852
strict: Whether the value should be a float or a value that can be converted to a float
849853
ref: optional unique identifier of the schema, used to reference the schema in other places
850854
metadata: Any other information you want to include with the schema, not used by pydantic-core
@@ -856,10 +860,6 @@ def fraction_schema(
856860
ge=ge,
857861
lt=lt,
858862
le=le,
859-
max_digits=max_digits,
860-
decimal_places=decimal_places,
861-
multiple_of=multiple_of,
862-
allow_inf_nan=allow_inf_nan,
863863
strict=strict,
864864
ref=ref,
865865
metadata=metadata,

src/errors/types.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ error_types! {
430430
DecimalWholeDigits {
431431
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
432432
},
433+
// Fraction errors
434+
FractionType {},
435+
FractionParsing {},
433436
// Complex errors
434437
ComplexType {},
435438
ComplexStrParsing {},
@@ -579,6 +582,8 @@ impl ErrorType {
579582
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
580583
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
581584
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
585+
Self::FractionParsing {..} => "Fraction input should be an integer, float, string or Fraction object",
586+
Self::FractionType {..} => "Fraction input should be an integer, float, string or Fraction object",
582587
Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
583588
Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
584589
}

src/input/input_abstract.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ pub trait Input<'py>: fmt::Debug {
115115

116116
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;
117117

118+
fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;
119+
118120
type Dict<'a>: ValidatedDict<'py>
119121
where
120122
Self: 'a;

src/input/input_json.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::input::return_enums::EitherComplex;
1212
use crate::lookup_key::{LookupKey, LookupPath};
1313
use crate::validators::complex::string_to_complex;
1414
use crate::validators::decimal::create_decimal;
15+
use crate::validators::fraction::create_fraction;
1516
use crate::validators::{TemporalUnitMode, ValBytesMode};
1617

1718
use super::datetime::{
@@ -199,6 +200,15 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
199200
}
200201
}
201202

203+
fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
204+
match self {
205+
JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => {
206+
create_fraction(&self.into_pyobject(py)?, self).map(ValidationMatch::strict)
207+
}
208+
_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
209+
}
210+
}
211+
202212
type Dict<'a>
203213
= &'a JsonObject<'data>
204214
where
@@ -454,6 +464,10 @@ impl<'py> Input<'py> for str {
454464
create_decimal(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax)
455465
}
456466

467+
fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
468+
create_fraction(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax)
469+
}
470+
457471
type Dict<'a> = Never;
458472

459473
#[cfg_attr(has_coverage_attribute, coverage(off))]

src/input/input_python.rs

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError,
1818
use crate::tools::{extract_i64, safe_repr};
1919
use crate::validators::complex::string_to_complex;
2020
use crate::validators::decimal::{create_decimal, get_decimal_type};
21+
use crate::validators::fraction::{create_fraction, get_fraction_type};
2122
use crate::validators::Exactness;
2223
use crate::validators::TemporalUnitMode;
2324
use crate::validators::ValBytesMode;
@@ -50,18 +51,6 @@ use super::{
5051

5152
static FRACTION_TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
5253

53-
pub fn get_fraction_type(py: Python<'_>) -> &Bound<'_, PyType> {
54-
FRACTION_TYPE
55-
.get_or_init(py, || {
56-
py.import("fractions")
57-
.and_then(|fractions_module| fractions_module.getattr("Fraction"))
58-
.unwrap()
59-
.extract()
60-
.unwrap()
61-
})
62-
.bind(py)
63-
}
64-
6554
pub(crate) fn downcast_python_input<'py, T: PyTypeCheck>(input: &(impl Input<'py> + ?Sized)) -> Option<&Bound<'py, T>> {
6655
input.as_python().and_then(|any| any.downcast::<T>().ok())
6756
}
@@ -70,6 +59,7 @@ pub(crate) fn input_as_python_instance<'a, 'py>(
7059
input: &'a (impl Input<'py> + ?Sized),
7160
class: &Bound<'py, PyType>,
7261
) -> Option<&'a Bound<'py, PyAny>> {
62+
println!("input_as_python_instance: class={:?}", class);
7363
input.as_python().filter(|any| any.is_instance(class).unwrap_or(false))
7464
}
7565

@@ -168,6 +158,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
168158
strict: bool,
169159
coerce_numbers_to_str: bool,
170160
) -> ValResult<ValidationMatch<EitherString<'_, 'py>>> {
161+
println!("[RUST]: Call validate_str with {:?}, and strict {:?}", self, strict);
171162
if let Ok(py_str) = self.downcast_exact::<PyString>() {
172163
return Ok(ValidationMatch::exact(py_str.clone().into()));
173164
} else if let Ok(py_str) = self.downcast::<PyString>() {
@@ -284,13 +275,14 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
284275

285276
'lax: {
286277
if !strict {
278+
println!("[RUST]: validate_int lax path for {:?}", self);
287279
return if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? {
288280
str_as_int(self, s)
289281
} else if self.is_exact_instance_of::<PyFloat>() {
290282
float_as_int(self, self.extract::<f64>()?)
291283
} else if let Ok(decimal) = self.validate_decimal(true, self.py()) {
292284
decimal_as_int(self, &decimal.into_inner())
293-
} else if self.is_instance(get_fraction_type(self.py()))? {
285+
} else if let Ok(fraction) = self.validate_fraction(true, self.py()) {
294286
fraction_as_int(self)
295287
} else if let Ok(float) = self.extract::<f64>() {
296288
float_as_int(self, float)
@@ -349,7 +341,49 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
349341
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
350342
}
351343

344+
fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
345+
println!("[RUST]: Call validate_fraction with {:?}, and strict {:?}", self, strict);
346+
let fraction_type = get_fraction_type(py);
347+
348+
// Fast path for existing decimal objects
349+
if self.is_exact_instance(fraction_type) {
350+
return Ok(ValidationMatch::exact(self.to_owned().clone()));
351+
}
352+
353+
if !strict {
354+
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>())
355+
{
356+
// Checking isinstance for str / int / bool is fast compared to decimal / float
357+
return create_fraction(self, self).map(ValidationMatch::lax);
358+
}
359+
360+
if self.is_instance_of::<PyFloat>() {
361+
return create_fraction(self.str()?.as_any(), self).map(ValidationMatch::lax);
362+
}
363+
}
364+
365+
if self.is_instance(fraction_type)? {
366+
// Upcast subclasses to decimal
367+
return create_decimal(self, self).map(ValidationMatch::strict);
368+
}
369+
370+
let error_type = if strict {
371+
ErrorType::IsInstanceOf {
372+
class: fraction_type
373+
.qualname()
374+
.and_then(|name| name.extract())
375+
.unwrap_or_else(|_| "Decimal".to_owned()),
376+
context: None,
377+
}
378+
} else {
379+
ErrorTypeDefaults::FractionType
380+
};
381+
382+
Err(ValError::new(error_type, self))
383+
}
384+
352385
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
386+
println!("[RUST]: Call validate_decimal with {:?}, and strict {:?}", self, strict);
353387
let decimal_type = get_decimal_type(py);
354388

355389
// Fast path for existing decimal objects

src/input/input_string.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath};
99
use crate::tools::safe_repr;
1010
use crate::validators::complex::string_to_complex;
1111
use crate::validators::decimal::create_decimal;
12+
use crate::validators::fraction::create_fraction;
1213
use crate::validators::{TemporalUnitMode, ValBytesMode};
1314

1415
use super::datetime::{
@@ -154,6 +155,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
154155
}
155156
}
156157

158+
fn validate_fraction(&self, _strict: bool, _py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
159+
match self {
160+
Self::String(s) => create_fraction(s, self).map(ValidationMatch::strict),
161+
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
162+
}
163+
}
164+
157165
type Dict<'a>
158166
= StringMappingDict<'py>
159167
where

0 commit comments

Comments
 (0)