Skip to content

Commit 3d8295e

Browse files
authored
Support complex numbers (#1331)
1 parent bb67044 commit 3d8295e

File tree

20 files changed

+601
-8
lines changed

20 files changed

+601
-8
lines changed

generate_self_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901
5151
if isinstance(obj, str):
5252
return {'type': obj}
53-
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
53+
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal, complex):
5454
return {'type': obj.__name__.lower()}
5555
elif is_typeddict(obj):
5656
return type_dict_schema(obj, definitions)

python/pydantic_core/core_schema.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,48 @@ def decimal_schema(
745745
)
746746

747747

748+
class ComplexSchema(TypedDict, total=False):
749+
type: Required[Literal['complex']]
750+
strict: bool
751+
ref: str
752+
metadata: Any
753+
serialization: SerSchema
754+
755+
756+
def complex_schema(
757+
*,
758+
strict: bool | None = None,
759+
ref: str | None = None,
760+
metadata: Any = None,
761+
serialization: SerSchema | None = None,
762+
) -> ComplexSchema:
763+
"""
764+
Returns a schema that matches a complex value, e.g.:
765+
766+
```py
767+
from pydantic_core import SchemaValidator, core_schema
768+
769+
schema = core_schema.complex_schema()
770+
v = SchemaValidator(schema)
771+
assert v.validate_python('1+2j') == complex(1, 2)
772+
assert v.validate_python(complex(1, 2)) == complex(1, 2)
773+
```
774+
775+
Args:
776+
strict: Whether the value should be a complex object instance or a value that can be converted to a complex object
777+
ref: optional unique identifier of the schema, used to reference the schema in other places
778+
metadata: Any other information you want to include with the schema, not used by pydantic-core
779+
serialization: Custom serialization schema
780+
"""
781+
return _dict_not_none(
782+
type='complex',
783+
strict=strict,
784+
ref=ref,
785+
metadata=metadata,
786+
serialization=serialization,
787+
)
788+
789+
748790
class StringSchema(TypedDict, total=False):
749791
type: Required[Literal['str']]
750792
pattern: Union[str, Pattern[str]]
@@ -3796,6 +3838,7 @@ def definition_reference_schema(
37963838
DefinitionsSchema,
37973839
DefinitionReferenceSchema,
37983840
UuidSchema,
3841+
ComplexSchema,
37993842
]
38003843
elif False:
38013844
CoreSchema: TypeAlias = Mapping[str, Any]
@@ -3851,6 +3894,7 @@ def definition_reference_schema(
38513894
'definitions',
38523895
'definition-ref',
38533896
'uuid',
3897+
'complex',
38543898
]
38553899

38563900
CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
@@ -3956,6 +4000,8 @@ def definition_reference_schema(
39564000
'decimal_max_digits',
39574001
'decimal_max_places',
39584002
'decimal_whole_digits',
4003+
'complex_type',
4004+
'complex_str_parsing',
39594005
]
39604006

39614007

src/errors/types.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ error_types! {
426426
DecimalWholeDigits {
427427
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
428428
},
429+
// Complex errors
430+
ComplexType {},
431+
ComplexStrParsing {},
429432
}
430433

431434
macro_rules! render {
@@ -569,6 +572,8 @@ impl ErrorType {
569572
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
570573
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
571574
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
575+
Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
576+
Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
572577
}
573578
}
574579

src/input/input_abstract.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::tools::py_err;
1010
use crate::validators::ValBytesMode;
1111

1212
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
13-
use super::return_enums::{EitherBytes, EitherInt, EitherString};
13+
use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString};
1414
use super::{EitherFloat, GenericIterator, ValidationMatch};
1515

1616
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -173,6 +173,8 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
173173
strict: bool,
174174
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
175175
) -> ValMatch<EitherTimedelta<'py>>;
176+
177+
fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValMatch<EitherComplex<'py>>;
176178
}
177179

178180
/// The problem to solve here is that iterating collections often returns owned

src/input/input_json.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;
88
use strum::EnumMessage;
99

1010
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
11+
use crate::input::return_enums::EitherComplex;
1112
use crate::lookup_key::{LookupKey, LookupPath};
13+
use crate::validators::complex::string_to_complex;
1214
use crate::validators::decimal::create_decimal;
1315
use crate::validators::ValBytesMode;
1416

@@ -304,6 +306,30 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
304306
_ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
305307
}
306308
}
309+
310+
fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
311+
match self {
312+
JsonValue::Str(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(
313+
&PyString::new_bound(py, s),
314+
self,
315+
)?))),
316+
JsonValue::Float(f) => {
317+
if !strict {
318+
Ok(ValidationMatch::lax(EitherComplex::Complex([*f, 0.0])))
319+
} else {
320+
Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self))
321+
}
322+
}
323+
JsonValue::Int(f) => {
324+
if !strict {
325+
Ok(ValidationMatch::lax(EitherComplex::Complex([(*f) as f64, 0.0])))
326+
} else {
327+
Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self))
328+
}
329+
}
330+
_ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
331+
}
332+
}
307333
}
308334

309335
/// Required for JSON Object keys so the string can behave like an Input
@@ -440,6 +466,13 @@ impl<'py> Input<'py> for str {
440466
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>> {
441467
bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax)
442468
}
469+
470+
fn validate_complex(&self, _strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
471+
Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(
472+
self.to_object(py).downcast_bound::<PyString>(py)?,
473+
self,
474+
)?)))
475+
}
443476
}
444477

445478
impl BorrowInput<'_> for &'_ String {

src/input/input_python.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@ use pyo3::prelude::*;
55

66
use pyo3::types::PyType;
77
use pyo3::types::{
8-
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
9-
PyMapping, PySet, PyString, PyTime, PyTuple,
8+
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
9+
PyList, PyMapping, PySet, PyString, PyTime, PyTuple,
1010
};
1111

1212
use pyo3::PyTypeCheck;
13+
use pyo3::PyTypeInfo;
1314
use speedate::MicrosecondsPrecisionOverflowBehavior;
1415

1516
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
1617
use crate::tools::{extract_i64, safe_repr};
18+
use crate::validators::complex::string_to_complex;
1719
use crate::validators::decimal::{create_decimal, get_decimal_type};
1820
use crate::validators::Exactness;
1921
use crate::validators::ValBytesMode;
@@ -25,6 +27,7 @@ use super::datetime::{
2527
EitherTime,
2628
};
2729
use super::input_abstract::ValMatch;
30+
use super::return_enums::EitherComplex;
2831
use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch};
2932
use super::shared::{
3033
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
@@ -598,6 +601,45 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
598601

599602
Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self))
600603
}
604+
605+
fn validate_complex<'a>(&'a self, strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
606+
if let Ok(complex) = self.downcast::<PyComplex>() {
607+
return Ok(ValidationMatch::strict(EitherComplex::Py(complex.to_owned())));
608+
}
609+
if strict {
610+
return Err(ValError::new(
611+
ErrorType::IsInstanceOf {
612+
class: PyComplex::type_object_bound(py)
613+
.qualname()
614+
.and_then(|name| name.extract())
615+
.unwrap_or_else(|_| "complex".to_owned()),
616+
context: None,
617+
},
618+
self,
619+
));
620+
}
621+
622+
if let Ok(s) = self.downcast::<PyString>() {
623+
// If input is not a valid complex string, instead of telling users to correct
624+
// the string, it makes more sense to tell them to provide any acceptable value
625+
// since they might have just given values of some incorrect types instead
626+
// of actually trying some complex strings.
627+
if let Ok(c) = string_to_complex(s, self) {
628+
return Ok(ValidationMatch::lax(EitherComplex::Py(c)));
629+
}
630+
} else if self.is_exact_instance_of::<PyFloat>() {
631+
return Ok(ValidationMatch::lax(EitherComplex::Complex([
632+
self.extract::<f64>().unwrap(),
633+
0.0,
634+
])));
635+
} else if self.is_exact_instance_of::<PyInt>() {
636+
return Ok(ValidationMatch::lax(EitherComplex::Complex([
637+
self.extract::<i64>().unwrap() as f64,
638+
0.0,
639+
])));
640+
}
641+
Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
642+
}
601643
}
602644

603645
impl<'py> BorrowInput<'py> for Bound<'py, PyAny> {

src/input/input_string.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}
77
use crate::input::py_string_str;
88
use crate::lookup_key::{LookupKey, LookupPath};
99
use crate::tools::safe_repr;
10+
use crate::validators::complex::string_to_complex;
1011
use crate::validators::decimal::create_decimal;
1112
use crate::validators::ValBytesMode;
1213

1314
use super::datetime::{
1415
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
1516
};
1617
use super::input_abstract::{Never, ValMatch};
18+
use super::return_enums::EitherComplex;
1719
use super::shared::{str_as_bool, str_as_float, str_as_int};
1820
use super::{
1921
Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input,
@@ -225,6 +227,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
225227
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
226228
}
227229
}
230+
231+
fn validate_complex(&self, _strict: bool, _py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
232+
match self {
233+
Self::String(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(s, self)?))),
234+
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
235+
}
236+
}
228237
}
229238

230239
impl<'py> BorrowInput<'py> for StringMapping<'py> {

src/input/return_enums.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use pyo3::intern;
1212
use pyo3::prelude::*;
1313
#[cfg(not(PyPy))]
1414
use pyo3::types::PyFunction;
15-
use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};
15+
use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};
1616

1717
use serde::{ser::Error, Serialize, Serializer};
1818

@@ -724,3 +724,30 @@ impl ToPyObject for Int {
724724
}
725725
}
726726
}
727+
728+
#[derive(Clone)]
729+
pub enum EitherComplex<'a> {
730+
Complex([f64; 2]),
731+
Py(Bound<'a, PyComplex>),
732+
}
733+
734+
impl<'a> IntoPy<PyObject> for EitherComplex<'a> {
735+
fn into_py(self, py: Python<'_>) -> PyObject {
736+
match self {
737+
Self::Complex(c) => PyComplex::from_doubles_bound(py, c[0], c[1]).into_py(py),
738+
Self::Py(c) => c.into_py(py),
739+
}
740+
}
741+
}
742+
743+
impl<'a> EitherComplex<'a> {
744+
pub fn as_f64(&self, py: Python<'_>) -> [f64; 2] {
745+
match self {
746+
EitherComplex::Complex(f) => *f,
747+
EitherComplex::Py(f) => [
748+
f.getattr(intern!(py, "real")).unwrap().extract().unwrap(),
749+
f.getattr(intern!(py, "imag")).unwrap().extract().unwrap(),
750+
],
751+
}
752+
}
753+
}

src/serializers/infer.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use pyo3::exceptions::PyTypeError;
44
use pyo3::intern;
55
use pyo3::prelude::*;
66
use pyo3::pybacked::PyBackedStr;
7+
use pyo3::types::PyComplex;
78
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};
89

910
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
@@ -226,6 +227,13 @@ pub(crate) fn infer_to_python_known(
226227
}
227228
PyList::new_bound(py, items).into_py(py)
228229
}
230+
ObType::Complex => {
231+
let dict = value.downcast::<PyDict>()?;
232+
let new_dict = PyDict::new_bound(py);
233+
let _ = new_dict.set_item("real", dict.get_item("real")?);
234+
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
235+
new_dict.into_py(py)
236+
}
229237
ObType::Path => value.str()?.into_py(py),
230238
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
231239
ObType::Unknown => {
@@ -274,6 +282,13 @@ pub(crate) fn infer_to_python_known(
274282
);
275283
iter.into_py(py)
276284
}
285+
ObType::Complex => {
286+
let dict = value.downcast::<PyDict>()?;
287+
let new_dict = PyDict::new_bound(py);
288+
let _ = new_dict.set_item("real", dict.get_item("real")?);
289+
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
290+
new_dict.into_py(py)
291+
}
277292
ObType::Unknown => {
278293
if let Some(fallback) = extra.fallback {
279294
let next_value = fallback.call1((value,))?;
@@ -402,6 +417,13 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
402417
ObType::None => serializer.serialize_none(),
403418
ObType::Int | ObType::IntSubclass => serialize!(Int),
404419
ObType::Bool => serialize!(bool),
420+
ObType::Complex => {
421+
let v = value.downcast::<PyComplex>().map_err(py_err_se_err)?;
422+
let mut map = serializer.serialize_map(Some(2))?;
423+
map.serialize_entry(&"real", &v.real())?;
424+
map.serialize_entry(&"imag", &v.imag())?;
425+
map.end()
426+
}
405427
ObType::Float | ObType::FloatSubclass => {
406428
let v = value.extract::<f64>().map_err(py_err_se_err)?;
407429
type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode)
@@ -647,7 +669,7 @@ pub(crate) fn infer_json_key_known<'a>(
647669
}
648670
Ok(Cow::Owned(key_build.finish()))
649671
}
650-
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
672+
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => {
651673
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
652674
}
653675
ObType::Dataclass | ObType::PydanticSerializable => {

0 commit comments

Comments
 (0)