Skip to content

Commit cd04138

Browse files
Refactor BigInt boilerplate (#1421)
1 parent d0384c7 commit cd04138

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

src/input/return_enums.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::errors::{
2020
py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, ValLineError, ValResult,
2121
};
2222
use crate::py_gc::PyGcTraverse;
23-
use crate::tools::{extract_i64, new_py_string, py_err};
23+
use crate::tools::{extract_i64, extract_int, new_py_string, py_err};
2424
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};
2525

2626
use super::{py_error_on_minusone, BorrowInput, Input};
@@ -662,6 +662,15 @@ pub enum Int {
662662
Big(BigInt),
663663
}
664664

665+
impl IntoPy<PyObject> for Int {
666+
fn into_py(self, py: Python<'_>) -> PyObject {
667+
match self {
668+
Self::I64(i) => i.into_py(py),
669+
Self::Big(big_i) => big_i.into_py(py),
670+
}
671+
}
672+
}
673+
665674
// The default serialization for BigInt is some internal representation which roundtrips efficiently
666675
// but is not the JSON value which users would expect to see.
667676
fn serialize_bigint_as_number<S>(big_int: &BigInt, serializer: S) -> Result<S::Ok, S::Error>
@@ -706,12 +715,9 @@ impl<'a> Rem for &'a Int {
706715

707716
impl FromPyObject<'_> for Int {
708717
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
709-
if let Some(i) = extract_i64(obj) {
710-
Ok(Int::I64(i))
711-
} else if let Ok(b) = obj.extract::<BigInt>() {
712-
Ok(Int::Big(b))
713-
} else {
714-
py_err!(PyTypeError; "Expected int, got {}", obj.get_type())
718+
match extract_int(obj) {
719+
Some(i) => Ok(i),
720+
None => py_err!(PyTypeError; "Expected int, got {}", obj.get_type()),
715721
}
716722
}
717723
}

src/serializers/infer.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::borrow::Cow;
22

3-
use num_bigint::BigInt;
43
use pyo3::exceptions::PyTypeError;
54
use pyo3::intern;
65
use pyo3::prelude::*;
@@ -12,7 +11,7 @@ use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
1211

1312
use crate::input::{EitherTimedelta, Int};
1413
use crate::serializers::type_serializers;
15-
use crate::tools::{extract_i64, py_err, safe_repr};
14+
use crate::tools::{extract_int, py_err, safe_repr};
1615
use crate::url::{PyMultiHostUrl, PyUrl};
1716

1817
use super::config::InfNanMode;
@@ -118,10 +117,8 @@ pub(crate) fn infer_to_python_known(
118117
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
119118
// have to do this to make sure subclasses of for example str are upcast to `str`
120119
ObType::IntSubclass => {
121-
if let Some(v) = extract_i64(value) {
122-
v.into_py(py)
123-
} else if let Ok(b) = value.extract::<BigInt>() {
124-
b.into_py(py)
120+
if let Some(i) = extract_int(value) {
121+
i.into_py(py)
125122
} else {
126123
return py_err!(PyTypeError; "Expected int, got {}", safe_repr(value));
127124
}

src/tools.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
use core::fmt;
22

3+
use num_bigint::BigInt;
4+
35
use pyo3::exceptions::PyKeyError;
46
use pyo3::prelude::*;
57
use pyo3::types::{PyDict, PyString};
68
use pyo3::{intern, FromPyObject};
79

10+
use crate::input::Int;
811
use jiter::{cached_py_string, pystring_fast_new, StringCacheMode};
912

1013
pub trait SchemaDict<'py> {
@@ -133,10 +136,15 @@ pub fn extract_i64(v: &Bound<'_, PyAny>) -> Option<i64> {
133136
// Can remove this after PyPy 7.3.17 is released
134137
return None;
135138
}
136-
137139
v.extract().ok()
138140
}
139141

142+
pub fn extract_int(v: &Bound<'_, PyAny>) -> Option<Int> {
143+
extract_i64(v)
144+
.map(Int::I64)
145+
.or_else(|| v.extract::<BigInt>().ok().map(Int::Big))
146+
}
147+
140148
pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCacheMode) -> Bound<'py, PyString> {
141149
// we could use `bytecount::num_chars(s.as_bytes()) == s.len()` as orjson does, but it doesn't appear to be faster
142150
let ascii_only = false;

0 commit comments

Comments
 (0)