Skip to content

Commit 97586f4

Browse files
Copilotjoocer
andcommitted
Fix segfault issues in Cython JSONL decoder
Co-authored-by: joocer <[email protected]>
1 parent 806c64f commit 97586f4

File tree

1 file changed

+39
-21
lines changed

1 file changed

+39
-21
lines changed

opteryx/compiled/structures/jsonl_decoder.pyx

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ Fast JSONL decoder using Cython for performance-critical operations.
1212
This decoder uses native C string operations instead of regex for better performance.
1313
"""
1414

15-
from libc.string cimport memchr, strlen, strstr
15+
from libc.string cimport memchr, strlen, strstr, memcmp
1616
from libc.stdlib cimport strtod, strtol, atoi
1717
from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_GET_SIZE
1818
from libc.stdint cimport int64_t
1919

2020
import pyarrow
21-
from opteryx.third_party.tktech import csimdjson as simdjson
2221

2322

2423
cdef inline const char* find_key_value(const char* line, Py_ssize_t line_len, const char* key, Py_ssize_t key_len, Py_ssize_t* value_start, Py_ssize_t* value_len):
@@ -37,19 +36,23 @@ cdef inline const char* find_key_value(const char* line, Py_ssize_t line_len, co
3736
cdef char first_char
3837
cdef int brace_count
3938
cdef int bracket_count
39+
cdef Py_ssize_t remaining
4040

4141
# Search for the key pattern: "key":
4242
while pos < end:
4343
# Find opening quote of a key
44-
key_pos = <const char*>memchr(pos, b'"', end - pos)
44+
remaining = end - pos
45+
if remaining <= 0:
46+
return NULL
47+
key_pos = <const char*>memchr(pos, b'"', <size_t>remaining)
4548
if key_pos == NULL:
4649
return NULL
4750

4851
key_pos += 1 # Move past the opening quote
4952

5053
# Check if this matches our key
5154
if (end - key_pos >= key_len and
52-
memcmp(key_pos, key, key_len) == 0 and
55+
memcmp(key_pos, key, <size_t>key_len) == 0 and
5356
key_pos[key_len] == b'"'):
5457

5558
# Found the key, now find the colon
@@ -71,7 +74,13 @@ cdef inline const char* find_key_value(const char* line, Py_ssize_t line_len, co
7174
quote_start = value_pos + 1
7275
quote_end = quote_start
7376
while quote_end < end:
74-
if quote_end[0] == b'"' and (quote_end == quote_start or quote_end[-1] != b'\\'):
77+
if quote_end[0] == b'"':
78+
# Check if it's escaped (previous char is backslash)
79+
if quote_end > quote_start and quote_end[-1] == b'\\':
80+
# It's escaped, keep going
81+
quote_end += 1
82+
continue
83+
# Found unescaped quote
7584
value_len[0] = (quote_end + 1) - value_pos
7685
return value_pos
7786
quote_end += 1
@@ -128,7 +137,8 @@ cdef inline const char* find_key_value(const char* line, Py_ssize_t line_len, co
128137
# Number - find end (space, comma, brace, bracket)
129138
quote_end = value_pos + 1
130139
while quote_end < end:
131-
if quote_end[0] in (b' ', b',', b'}', b']', b'\t', b'\n'):
140+
# Check for delimiter characters
141+
if quote_end[0] == b' ' or quote_end[0] == b',' or quote_end[0] == b'}' or quote_end[0] == b']' or quote_end[0] == b'\t' or quote_end[0] == b'\n':
132142
break
133143
quote_end += 1
134144
value_len[0] = quote_end - value_pos
@@ -139,10 +149,6 @@ cdef inline const char* find_key_value(const char* line, Py_ssize_t line_len, co
139149
return NULL
140150

141151

142-
cdef extern from "string.h":
143-
int memcmp(const void *s1, const void *s2, size_t n)
144-
145-
146152
cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_types, Py_ssize_t sample_size=100):
147153
"""
148154
Fast JSONL decoder that extracts values using C string operations.
@@ -174,9 +180,9 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
174180
cdef dict result = {}
175181
cdef Py_ssize_t num_lines = 0
176182
cdef Py_ssize_t i
177-
cdef char* end_ptr
178183
cdef bytes value_bytes
179184
cdef str value_str
185+
cdef Py_ssize_t remaining
180186

181187
# Initialize column data lists
182188
for col in column_names:
@@ -186,7 +192,10 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
186192
# Count lines first
187193
cdef const char* newline_pos = pos
188194
while newline_pos < end:
189-
newline_pos = <const char*>memchr(newline_pos, b'\n', end - newline_pos)
195+
remaining = end - newline_pos
196+
if remaining <= 0:
197+
break
198+
newline_pos = <const char*>memchr(newline_pos, b'\n', <size_t>remaining)
190199
if newline_pos == NULL:
191200
break
192201
num_lines += 1
@@ -201,7 +210,10 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
201210
for i in range(num_lines):
202211
# Find line end
203212
line_start = pos
204-
line_end = <const char*>memchr(line_start, b'\n', end - line_start)
213+
remaining = end - line_start
214+
if remaining <= 0:
215+
break
216+
line_end = <const char*>memchr(line_start, b'\n', <size_t>remaining)
205217
if line_end == NULL:
206218
line_end = end
207219

@@ -226,6 +238,10 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
226238
result[col].append(None)
227239
continue
228240

241+
# Create a safe bytes object from the C pointer
242+
# This is crucial to avoid segfaults when slicing
243+
value_bytes = PyBytes_FromStringAndSize(value_ptr, value_len)
244+
229245
# Parse value based on type
230246
if col_type == 'bool':
231247
if value_len == 4 and memcmp(value_ptr, b"true", 4) == 0:
@@ -239,8 +255,6 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
239255
if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
240256
result[col].append(None)
241257
else:
242-
# Use strtol for integer parsing
243-
value_bytes = value_ptr[:value_len]
244258
try:
245259
result[col].append(int(value_bytes))
246260
except ValueError:
@@ -250,8 +264,6 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
250264
if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
251265
result[col].append(None)
252266
else:
253-
# Use strtod for float parsing
254-
value_bytes = value_ptr[:value_len]
255267
try:
256268
result[col].append(float(value_bytes))
257269
except ValueError:
@@ -260,11 +272,12 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
260272
elif col_type == 'str':
261273
if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
262274
result[col].append(None)
263-
elif value_ptr[0] == b'"':
275+
elif value_ptr[0] == b'"' and value_len >= 2:
264276
# String value - extract without quotes
265-
value_bytes = value_ptr[1:value_len-1]
277+
# Safely extract the string content
278+
string_content = PyBytes_FromStringAndSize(value_ptr + 1, value_len - 2)
266279
try:
267-
value_str = value_bytes.decode('utf-8')
280+
value_str = string_content.decode('utf-8')
268281
# Simple unescape
269282
value_str = value_str.replace('\\n', '\n').replace('\\t', '\t').replace('\\"', '"').replace('\\\\', '\\')
270283
result[col].append(value_str)
@@ -275,7 +288,6 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
275288

276289
else:
277290
# For other types (list, dict, null), fall back to Python
278-
value_bytes = value_ptr[:value_len]
279291
if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
280292
result[col].append(None)
281293
else:
@@ -293,3 +305,9 @@ cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_ty
293305
pos = line_end + 1 if line_end < end else end
294306

295307
return (num_lines, len(column_names), result)
308+
309+
310+
# Declare the C function we need
311+
cdef extern from "Python.h":
312+
bytes PyBytes_FromStringAndSize(const char *v, Py_ssize_t len)
313+

0 commit comments

Comments
 (0)