Skip to content

Commit 9abadb7

Browse files
authored
replace cached_py_string and pystring_fast_new with safe alternatives (#201)
1 parent 37d62dd commit 9abadb7

File tree

5 files changed

+147
-78
lines changed

5 files changed

+147
-78
lines changed

crates/jiter/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ pub use value::{JsonArray, JsonObject, JsonValue};
172172
#[cfg(feature = "python")]
173173
pub use py_lossless_float::{FloatMode, LosslessFloat};
174174
#[cfg(feature = "python")]
175-
pub use py_string_cache::{cache_clear, cache_usage, cached_py_string, pystring_fast_new, StringCacheMode};
175+
pub use py_string_cache::{
176+
cache_clear, cache_usage, cached_py_string, cached_py_string_ascii, pystring_ascii_new, StringCacheMode,
177+
};
176178
#[cfg(feature = "python")]
177179
pub use python::{map_json_error, PythonParse};
178180

crates/jiter/src/py_string_cache.rs

Lines changed: 72 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use pyo3::exceptions::{PyTypeError, PyValueError};
55
use pyo3::prelude::*;
66
use pyo3::types::{PyBool, PyString};
77

8+
use crate::string_decoder::StringOutput;
9+
810
#[derive(Debug, Clone, Copy)]
911
pub enum StringCacheMode {
1012
All,
@@ -50,38 +52,40 @@ impl From<bool> for StringCacheMode {
5052
}
5153

5254
pub trait StringMaybeCache {
53-
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString>;
55+
fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString>;
5456

55-
fn get_value<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
56-
Self::get_key(py, json_str, ascii_only)
57+
fn get_value<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
58+
Self::get_key(py, string_output)
5759
}
5860
}
5961

6062
pub struct StringCacheAll;
6163

6264
impl StringMaybeCache for StringCacheAll {
63-
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
64-
cached_py_string(py, json_str, ascii_only)
65+
fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
66+
// Safety: string_output carries the safety information
67+
unsafe { cached_py_string_maybe_ascii(py, string_output.as_str(), string_output.ascii_only()) }
6568
}
6669
}
6770

6871
pub struct StringCacheKeys;
6972

7073
impl StringMaybeCache for StringCacheKeys {
71-
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
72-
cached_py_string(py, json_str, ascii_only)
74+
fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
75+
// Safety: string_output carries the safety information
76+
unsafe { cached_py_string_maybe_ascii(py, string_output.as_str(), string_output.ascii_only()) }
7377
}
7478

75-
fn get_value<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
76-
pystring_fast_new(py, json_str, ascii_only)
79+
fn get_value<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
80+
unsafe { pystring_fast_new_maybe_ascii(py, string_output.as_str(), string_output.ascii_only()) }
7781
}
7882
}
7983

8084
pub struct StringNoCache;
8185

8286
impl StringMaybeCache for StringNoCache {
83-
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
84-
pystring_fast_new(py, json_str, ascii_only)
87+
fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
88+
unsafe { pystring_fast_new_maybe_ascii(py, string_output.as_str(), string_output.ascii_only()) }
8589
}
8690
}
8791

@@ -108,12 +112,33 @@ pub fn cache_clear() {
108112
get_string_cache().clear();
109113
}
110114

111-
pub fn cached_py_string<'py>(py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
115+
/// Create a cached Python `str` from a string slice
116+
#[inline]
117+
pub fn cached_py_string<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
118+
// SAFETY: not setting ascii-only
119+
unsafe { cached_py_string_maybe_ascii(py, s, false) }
120+
}
121+
122+
/// Create a cached Python `str` from a string slice.
123+
///
124+
/// # Safety
125+
///
126+
/// Caller must pass ascii-only string.
127+
#[inline]
128+
pub unsafe fn cached_py_string_ascii<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
129+
// SAFETY: caller upholds invariant
130+
unsafe { cached_py_string_maybe_ascii(py, s, true) }
131+
}
132+
133+
/// # Safety
134+
///
135+
/// Caller must match the ascii_only flag to the string passed in.
136+
unsafe fn cached_py_string_maybe_ascii<'py>(py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
112137
// from tests, 0 and 1 character strings are faster not cached
113138
if (2..64).contains(&s.len()) {
114139
get_string_cache().get_or_insert(py, s, ascii_only)
115140
} else {
116-
pystring_fast_new(py, s, ascii_only)
141+
pystring_fast_new_maybe_ascii(py, s, ascii_only)
117142
}
118143
}
119144

@@ -146,13 +171,18 @@ impl Default for PyStringCache {
146171
impl PyStringCache {
147172
/// Lookup the cache for an entry with the given string. If it exists, return it.
148173
/// If it is not set or has a different string, insert it and return it.
149-
fn get_or_insert<'py>(&mut self, py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
174+
///
175+
/// # Safety
176+
///
177+
/// `ascii_only` must only be set to `true` if the string is guaranteed to be ASCII only.
178+
unsafe fn get_or_insert<'py>(&mut self, py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
150179
let hash = self.hash_builder.hash_one(s);
151180

152181
let hash_index = hash as usize % CAPACITY;
153182

154183
let set_entry = |entry: &mut Entry| {
155-
let py_str = pystring_fast_new(py, s, ascii_only);
184+
// SAFETY: caller upholds invariant
185+
let py_str = unsafe { pystring_fast_new_maybe_ascii(py, s, ascii_only) };
156186
if let Some((_, old_py_str)) = entry.replace((hash, py_str.clone().unbind())) {
157187
// micro-optimization: bind the old entry before dropping it so that PyO3 can
158188
// fast-path the drop (Bound::drop is faster than Py::drop)
@@ -199,8 +229,14 @@ impl PyStringCache {
199229
}
200230
}
201231

202-
pub fn pystring_fast_new<'py>(py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
232+
/// Creatate a new Python `str` from a string slice, with a fast path for ASCII strings
233+
///
234+
/// # Safety
235+
///
236+
/// `ascii_only` must only be set to `true` if the string is guaranteed to be ASCII only.
237+
unsafe fn pystring_fast_new_maybe_ascii<'py>(py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
203238
if ascii_only {
239+
// SAFETY: caller upholds invariant
204240
unsafe { pystring_ascii_new(py, s) }
205241
} else {
206242
PyString::new(py, s)
@@ -209,22 +245,24 @@ pub fn pystring_fast_new<'py>(py: Python<'py>, s: &str, ascii_only: bool) -> Bou
209245

210246
/// Faster creation of PyString from an ASCII string, inspired by
211247
/// https://github.com/ijl/orjson/blob/3.10.0/src/str/create.rs#L41
212-
#[cfg(not(any(PyPy, GraalPy)))]
213-
unsafe fn pystring_ascii_new<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
214-
// disabled on everything except tier-1 platforms because of a crash in the built wheels from CI,
215-
// see https://github.com/pydantic/jiter/pull/175
216-
217-
let ptr = pyo3::ffi::PyUnicode_New(s.len() as isize, 127);
218-
// see https://github.com/pydantic/jiter/pull/72#discussion_r1545485907
219-
debug_assert_eq!(pyo3::ffi::PyUnicode_KIND(ptr), pyo3::ffi::PyUnicode_1BYTE_KIND);
220-
let data_ptr = pyo3::ffi::PyUnicode_DATA(ptr).cast();
221-
core::ptr::copy_nonoverlapping(s.as_ptr(), data_ptr, s.len());
222-
core::ptr::write(data_ptr.add(s.len()), 0);
223-
Bound::from_owned_ptr(py, ptr).downcast_into_unchecked()
224-
}
225-
226-
// unoptimized version (albeit not that much slower) on other platforms
227-
#[cfg(any(PyPy, GraalPy))]
228-
unsafe fn pystring_ascii_new<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
229-
PyString::new(py, s)
248+
///
249+
/// # Safety
250+
///
251+
/// `s` must be ASCII only
252+
pub unsafe fn pystring_ascii_new<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
253+
#[cfg(not(any(PyPy, GraalPy, Py_LIMITED_API)))]
254+
{
255+
let ptr = pyo3::ffi::PyUnicode_New(s.len() as isize, 127);
256+
// see https://github.com/pydantic/jiter/pull/72#discussion_r1545485907
257+
debug_assert_eq!(pyo3::ffi::PyUnicode_KIND(ptr), pyo3::ffi::PyUnicode_1BYTE_KIND);
258+
let data_ptr = pyo3::ffi::PyUnicode_DATA(ptr).cast();
259+
core::ptr::copy_nonoverlapping(s.as_ptr(), data_ptr, s.len());
260+
core::ptr::write(data_ptr.add(s.len()), 0);
261+
Bound::from_owned_ptr(py, ptr).downcast_into_unchecked()
262+
}
263+
264+
#[cfg(any(PyPy, GraalPy, Py_LIMITED_API))]
265+
{
266+
PyString::new(py, s)
267+
}
230268
}

crates/jiter/src/python.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ impl<StringCache: StringMaybeCache, KeyCheck: MaybeKeyCheck, ParseNumber: MaybeP
138138
let s = self
139139
.parser
140140
.consume_string::<StringDecoder>(&mut self.tape, self.partial_mode.allow_trailing_str())?;
141-
Ok(StringCache::get_value(py, s.as_str(), s.ascii_only()).into_any())
141+
Ok(StringCache::get_value(py, s).into_any())
142142
}
143143
Peek::Array => {
144144
let peek_first = match self.parser.array_first() {
@@ -198,14 +198,14 @@ impl<StringCache: StringMaybeCache, KeyCheck: MaybeKeyCheck, ParseNumber: MaybeP
198198
if let Some(first_key) = self.parser.object_first::<StringDecoder>(&mut self.tape)? {
199199
let first_key_s = first_key.as_str();
200200
check_keys.check(first_key_s, self.parser.index)?;
201-
let first_key = StringCache::get_key(py, first_key_s, first_key.ascii_only());
201+
let first_key = StringCache::get_key(py, first_key);
202202
let peek = self.parser.peek()?;
203203
let first_value = self.check_take_value(py, peek)?;
204204
set_item(first_key, first_value);
205205
while let Some(key) = self.parser.object_step::<StringDecoder>(&mut self.tape)? {
206206
let key_s = key.as_str();
207207
check_keys.check(key_s, self.parser.index)?;
208-
let key = StringCache::get_key(py, key_s, key.ascii_only());
208+
let key = StringCache::get_key(py, key);
209209
let peek = self.parser.peek()?;
210210
let value = self.check_take_value(py, peek)?;
211211
set_item(key, value);

crates/jiter/src/string_decoder.rs

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::borrow::Cow;
21
use std::ops::Range;
32
use std::str::{from_utf8, from_utf8_unchecked};
43

@@ -26,48 +25,87 @@ where
2625
pub struct StringDecoder;
2726

2827
#[derive(Debug)]
29-
pub enum StringOutput<'t, 'j>
28+
pub enum StringOutputType<'t, 'j>
3029
where
3130
'j: 't,
3231
{
33-
Tape(&'t str, bool),
34-
Data(&'j str, bool),
32+
Tape(&'t str),
33+
Data(&'j str),
3534
}
3635

37-
impl From<StringOutput<'_, '_>> for String {
38-
fn from(val: StringOutput) -> Self {
39-
match val {
40-
StringOutput::Tape(s, _) => s.to_owned(),
41-
StringOutput::Data(s, _) => s.to_owned(),
42-
}
36+
/// This submodule is used to create a safety boundary where the `ascii_only`
37+
/// flag can be used to carry soundness information about the string.
38+
mod string_output {
39+
use std::borrow::Cow;
40+
41+
use super::StringOutputType;
42+
43+
#[derive(Debug)]
44+
pub struct StringOutput<'t, 'j>
45+
where
46+
'j: 't,
47+
{
48+
data: StringOutputType<'t, 'j>,
49+
// SAFETY: this is used as an invariant to determine if the string is ascii only
50+
// so this should not be set except when known
51+
ascii_only: bool,
4352
}
44-
}
4553

46-
impl<'j> From<StringOutput<'_, 'j>> for Cow<'j, str> {
47-
fn from(val: StringOutput<'_, 'j>) -> Self {
48-
match val {
49-
StringOutput::Tape(s, _) => s.to_owned().into(),
50-
StringOutput::Data(s, _) => s.into(),
54+
impl From<StringOutput<'_, '_>> for String {
55+
fn from(val: StringOutput) -> Self {
56+
match val.data {
57+
StringOutputType::Tape(s) | StringOutputType::Data(s) => s.to_owned(),
58+
}
5159
}
5260
}
53-
}
5461

55-
impl<'t> StringOutput<'t, '_> {
56-
pub fn as_str(&self) -> &'t str {
57-
match self {
58-
Self::Tape(s, _) => s,
59-
Self::Data(s, _) => s,
62+
impl<'j> From<StringOutput<'_, 'j>> for Cow<'j, str> {
63+
fn from(val: StringOutput<'_, 'j>) -> Self {
64+
match val.data {
65+
StringOutputType::Tape(s) => s.to_owned().into(),
66+
StringOutputType::Data(s) => s.into(),
67+
}
6068
}
6169
}
6270

63-
pub fn ascii_only(&self) -> bool {
64-
match self {
65-
Self::Tape(_, ascii_only) => *ascii_only,
66-
Self::Data(_, ascii_only) => *ascii_only,
71+
impl<'t, 'j> StringOutput<'t, 'j>
72+
where
73+
'j: 't,
74+
{
75+
/// # Safety
76+
///
77+
/// `accii_only` must only be set to true if the string is ascii only
78+
pub unsafe fn tape(data: &'t str, ascii_only: bool) -> Self {
79+
StringOutput {
80+
data: StringOutputType::Tape(data),
81+
ascii_only,
82+
}
83+
}
84+
85+
/// # Safety
86+
///
87+
/// `accii_only` must only be set to true if the string is ascii only
88+
pub unsafe fn data(data: &'j str, ascii_only: bool) -> Self {
89+
StringOutput {
90+
data: StringOutputType::Data(data),
91+
ascii_only,
92+
}
93+
}
94+
95+
pub fn as_str(&self) -> &'t str {
96+
match self.data {
97+
StringOutputType::Tape(s) | StringOutputType::Data(s) => s,
98+
}
99+
}
100+
101+
pub fn ascii_only(&self) -> bool {
102+
self.ascii_only
67103
}
68104
}
69105
}
70106

107+
pub use string_output::StringOutput;
108+
71109
impl<'t, 'j> AbstractStringDecoder<'t, 'j> for StringDecoder
72110
where
73111
'j: 't,
@@ -85,7 +123,7 @@ where
85123
match decode_chunk(data, start, true, allow_partial)? {
86124
(StringChunk::StringEnd, ascii_only, index) => {
87125
let s = to_str(&data[start..index], ascii_only, start)?;
88-
Ok((StringOutput::Data(s, ascii_only), index + 1))
126+
Ok((unsafe { StringOutput::data(s, ascii_only) }, index + 1))
89127
}
90128
(StringChunk::Backslash, ascii_only, index) => {
91129
decode_to_tape(data, index, tape, start, ascii_only, allow_partial)
@@ -134,7 +172,7 @@ fn decode_to_tape<'t, 'j>(
134172
tape.extend_from_slice(&data[index..new_index]);
135173
index = new_index + 1;
136174
let s = to_str(tape, ascii_only, start)?;
137-
return Ok((StringOutput::Tape(s, ascii_only), index));
175+
return Ok((unsafe { StringOutput::tape(s, ascii_only) }, index));
138176
}
139177
(StringChunk::Backslash, ascii_only_new, index_new) => {
140178
ascii_only = ascii_only_new;

crates/jiter/tests/python.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use pyo3::prelude::*;
22
use pyo3::types::PyString;
33

4-
use jiter::{pystring_fast_new, JsonValue, PythonParse, StringCacheMode};
4+
use jiter::{pystring_ascii_new, JsonValue, PythonParse, StringCacheMode};
55

66
#[cfg(feature = "num-bigint")]
77
#[test]
@@ -71,19 +71,10 @@ fn test_cache_into() {
7171
}
7272

7373
#[test]
74-
fn test_pystring_fast_new_non_ascii() {
75-
let json = "£100 💩";
76-
Python::with_gil(|py| {
77-
let s = pystring_fast_new(py, json, false);
78-
assert_eq!(s.to_string(), "£100 💩");
79-
});
80-
}
81-
82-
#[test]
83-
fn test_pystring_fast_new_ascii() {
74+
fn test_pystring_ascii_new() {
8475
let json = "100abc";
8576
Python::with_gil(|py| {
86-
let s = pystring_fast_new(py, json, true);
77+
let s = unsafe { pystring_ascii_new(py, json) };
8778
assert_eq!(s.to_string(), "100abc");
8879
});
8980
}

0 commit comments

Comments
 (0)