Skip to content

Commit 1558aca

Browse files
committed
add an option to disallow nan
1 parent d6906a9 commit 1558aca

File tree

8 files changed

+19
-6
lines changed

8 files changed

+19
-6
lines changed

pysrc/orjson/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class Fragment(tuple):
1717
contents: Union[bytes, str]
1818

1919
OPT_APPEND_NEWLINE: int
20+
OPT_DISALLOW_NAN: int
2021
OPT_INDENT_2: int
2122
OPT_NAIVE_UTC: int
2223
OPT_NON_STR_KEYS: int

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ pub unsafe extern "C" fn orjson_init_exec(mptr: *mut PyObject) -> c_int {
126126
add!(mptr, "Fragment\0", typeref::FRAGMENT_TYPE as *mut PyObject);
127127

128128
opt!(mptr, "OPT_APPEND_NEWLINE\0", opt::APPEND_NEWLINE);
129+
opt!(mptr, "OPT_DISALLOW_NAN\0", opt::DISALLOW_NAN);
129130
opt!(mptr, "OPT_INDENT_2\0", opt::INDENT_2);
130131
opt!(mptr, "OPT_NAIVE_UTC\0", opt::NAIVE_UTC);
131132
opt!(mptr, "OPT_NON_STR_KEYS\0", opt::NON_STR_KEYS);

src/opt.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub const PASSTHROUGH_SUBCLASS: Opt = 1 << 8;
1414
pub const PASSTHROUGH_DATETIME: Opt = 1 << 9;
1515
pub const APPEND_NEWLINE: Opt = 1 << 10;
1616
pub const PASSTHROUGH_DATACLASS: Opt = 1 << 11;
17+
pub const DISALLOW_NAN: Opt = 1 << 12;
1718

1819
// deprecated
1920
pub const SERIALIZE_DATACLASS: Opt = 0;
@@ -25,6 +26,7 @@ pub const NOT_PASSTHROUGH: Opt =
2526
!(PASSTHROUGH_DATETIME | PASSTHROUGH_DATACLASS | PASSTHROUGH_SUBCLASS);
2627

2728
pub const MAX_OPT: i32 = (APPEND_NEWLINE
29+
| DISALLOW_NAN
2830
| INDENT_2
2931
| NAIVE_UTC
3032
| NON_STR_KEYS

src/serialize/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use core::ptr::NonNull;
66
pub enum SerializeError {
77
DatetimeLibraryUnsupported,
88
DefaultRecursionLimit,
9+
FloatNotFinite,
910
Integer53Bits,
1011
Integer64Bits,
1112
InvalidStr,
@@ -35,6 +36,9 @@ impl std::fmt::Display for SerializeError {
3536
SerializeError::DefaultRecursionLimit => {
3637
write!(f, "default serializer exceeds recursion limit")
3738
}
39+
SerializeError::FloatNotFinite => {
40+
write!(f, "Cannot serialize Infinity or NaN")
41+
}
3842
SerializeError::Integer53Bits => write!(f, "Integer exceeds 53-bit range"),
3943
SerializeError::Integer64Bits => write!(f, "Integer exceeds 64-bit range"),
4044
SerializeError::InvalidStr => write!(f, "{}", crate::util::INVALID_STR),

src/serialize/per_type/dict.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ macro_rules! impl_serialize_entry {
110110
}
111111
ObType::Float => {
112112
$map.serialize_key($key).unwrap();
113-
$map.serialize_value(&FloatSerializer::new($value))?;
113+
$map.serialize_value(&FloatSerializer::new($value, $self.state.opts()))?;
114114
}
115115
ObType::Bool => {
116116
$map.serialize_key($key).unwrap();

src/serialize/per_type/float.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
22

3+
use crate::opt::{Opt, DISALLOW_NAN};
4+
use crate::serialize::error::SerializeError;
35
use serde::ser::{Serialize, Serializer};
46

5-
#[repr(transparent)]
67
pub struct FloatSerializer {
78
ptr: *mut pyo3_ffi::PyObject,
9+
opts: Opt,
810
}
911

1012
impl FloatSerializer {
11-
pub fn new(ptr: *mut pyo3_ffi::PyObject) -> Self {
12-
FloatSerializer { ptr: ptr }
13+
pub fn new(ptr: *mut pyo3_ffi::PyObject, opts: Opt) -> Self {
14+
FloatSerializer { ptr: ptr, opts: opts }
1315
}
1416
}
1517

@@ -22,6 +24,9 @@ impl Serialize for FloatSerializer {
2224
let value = ffi!(PyFloat_AS_DOUBLE(self.ptr));
2325
#[cfg(yyjson_allow_inf_and_nan)]
2426
{
27+
if unlikely!(opt_enabled!(self.opts, DISALLOW_NAN)) && !value.is_finite() {
28+
err!(SerializeError::FloatNotFinite)
29+
}
2530
serializer.serialize_f64(value)
2631
}
2732
#[cfg(not(yyjson_allow_inf_and_nan))]

src/serialize/per_type/list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ impl Serialize for ListTupleSerializer {
106106
seq.serialize_element(&NoneSerializer::new()).unwrap();
107107
}
108108
ObType::Float => {
109-
seq.serialize_element(&FloatSerializer::new(value))?;
109+
seq.serialize_element(&FloatSerializer::new(value, self.state.opts()))?;
110110
}
111111
ObType::Bool => {
112112
seq.serialize_element(&BoolSerializer::new(value)).unwrap();

src/serialize/serializer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ impl Serialize for PyObjectSerializer {
7070
ObType::StrSubclass => StrSubclassSerializer::new(self.ptr).serialize(serializer),
7171
ObType::Int => IntSerializer::new(self.ptr, self.state.opts()).serialize(serializer),
7272
ObType::None => NoneSerializer::new().serialize(serializer),
73-
ObType::Float => FloatSerializer::new(self.ptr).serialize(serializer),
73+
ObType::Float => FloatSerializer::new(self.ptr, self.state.opts()).serialize(serializer),
7474
ObType::Bool => BoolSerializer::new(self.ptr).serialize(serializer),
7575
ObType::Datetime => DateTime::new(self.ptr, self.state.opts()).serialize(serializer),
7676
ObType::Date => Date::new(self.ptr).serialize(serializer),

0 commit comments

Comments
 (0)