Skip to content

Commit b88bfbc

Browse files
Copilot and joocer
committed
Add Cython-based fast JSONL decoder for better performance
Co-authored-by: joocer <[email protected]>
1 parent 6dbd08d commit b88bfbc

File tree

3 files changed

+384
-0
lines changed

3 files changed

+384
-0
lines changed
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
# cython: language_level=3
2+
# cython: nonecheck=False
3+
# cython: cdivision=True
4+
# cython: initializedcheck=False
5+
# cython: infer_types=True
6+
# cython: wraparound=False
7+
# cython: boundscheck=False
8+
9+
"""
10+
Fast JSONL decoder using Cython for performance-critical operations.
11+
12+
This decoder uses native C string operations instead of regex for better performance.
13+
"""
14+
15+
import numpy
16+
cimport numpy
17+
numpy.import_array()
18+
19+
from libc.string cimport memchr, strlen, strstr
20+
from libc.stdlib cimport strtod, strtol, atoi
21+
from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_GET_SIZE
22+
from libc.stdint cimport int64_t
23+
24+
import pyarrow
25+
from opteryx.third_party.tktech import csimdjson as simdjson
26+
27+
28+
cdef inline const char* find_key_value(const char* line, Py_ssize_t line_len, const char* key, Py_ssize_t key_len, Py_ssize_t* value_start, Py_ssize_t* value_len) nogil:
    """
    Find the value for a given key in a single JSON line.

    Scans for the pattern "key": and returns a pointer to the first byte of
    the associated value, or NULL if the key is not found.  On success,
    value_start receives the value's offset from `line` and value_len its
    byte length (including surrounding quotes for string values).

    NOTE(review): this is a heuristic scanner, not a full JSON parser — a
    quoted token inside a string *value* can still be mistaken for a key.
    """
    cdef const char* pos = line
    cdef const char* end = line + line_len
    cdef const char* key_pos
    cdef const char* value_pos
    cdef const char* quote_start
    cdef const char* quote_end
    cdef const char* back
    cdef char first_char
    cdef char c
    cdef int brace_count
    cdef int bracket_count
    cdef Py_ssize_t backslashes
    cdef bint in_string

    # Search for the key pattern: "key":
    while pos < end:
        # Find opening quote of a candidate key
        key_pos = <const char*>memchr(pos, b'"', end - pos)
        if key_pos == NULL:
            return NULL

        key_pos += 1  # Move past the opening quote

        # Check if this matches our key.  Strict `>` (not `>=`) so reading
        # key_pos[key_len] (the expected closing quote) never runs past `end`
        # — the original comparison allowed a one-byte out-of-bounds read.
        if (end - key_pos > key_len and
            memcmp(key_pos, key, key_len) == 0 and
            key_pos[key_len] == b'"'):

            # Found the key; step past the closing quote to reach the value
            value_pos = key_pos + key_len + 1

            # Skip whitespace and the colon separator
            while value_pos < end and (value_pos[0] == b' ' or value_pos[0] == b'\t' or value_pos[0] == b':'):
                value_pos += 1

            if value_pos >= end:
                return NULL

            first_char = value_pos[0]
            value_start[0] = value_pos - line

            # Determine value type and find its end
            if first_char == b'"':
                # String value: find the closing quote.  A quote is escaped
                # only when preceded by an ODD number of backslashes; the old
                # single-character look-behind misread e.g. "a\\" (escaped
                # backslash followed by the real closing quote).
                quote_start = value_pos + 1
                quote_end = quote_start
                while quote_end < end:
                    if quote_end[0] == b'"':
                        backslashes = 0
                        back = quote_end - 1
                        while back >= quote_start and back[0] == b'\\':
                            backslashes += 1
                            back -= 1
                        if backslashes % 2 == 0:
                            value_len[0] = (quote_end + 1) - value_pos
                            return value_pos
                    quote_end += 1
                return NULL

            elif first_char == b'{':
                # Object: balance braces, ignoring braces inside nested
                # string literals (the old count broke on e.g. {"a": "}"}).
                brace_count = 1
                in_string = False
                quote_end = value_pos + 1
                while quote_end < end and brace_count > 0:
                    c = quote_end[0]
                    if in_string:
                        if c == b'\\' and quote_end + 1 < end:
                            quote_end += 1  # skip the escaped character
                        elif c == b'"':
                            in_string = False
                    elif c == b'"':
                        in_string = True
                    elif c == b'{':
                        brace_count += 1
                    elif c == b'}':
                        brace_count -= 1
                    quote_end += 1
                value_len[0] = quote_end - value_pos
                return value_pos

            elif first_char == b'[':
                # Array: balance brackets, ignoring brackets inside strings
                bracket_count = 1
                in_string = False
                quote_end = value_pos + 1
                while quote_end < end and bracket_count > 0:
                    c = quote_end[0]
                    if in_string:
                        if c == b'\\' and quote_end + 1 < end:
                            quote_end += 1  # skip the escaped character
                        elif c == b'"':
                            in_string = False
                    elif c == b'"':
                        in_string = True
                    elif c == b'[':
                        bracket_count += 1
                    elif c == b']':
                        bracket_count -= 1
                    quote_end += 1
                value_len[0] = quote_end - value_pos
                return value_pos

            elif first_char == b'n':
                # null
                if end - value_pos >= 4 and memcmp(value_pos, b"null", 4) == 0:
                    value_len[0] = 4
                    return value_pos
                return NULL

            elif first_char == b't':
                # true
                if end - value_pos >= 4 and memcmp(value_pos, b"true", 4) == 0:
                    value_len[0] = 4
                    return value_pos
                return NULL

            elif first_char == b'f':
                # false
                if end - value_pos >= 5 and memcmp(value_pos, b"false", 5) == 0:
                    value_len[0] = 5
                    return value_pos
                return NULL

            else:
                # Number: scan until a terminator (whitespace or a closing
                # structural character).  Explicit comparisons instead of the
                # Python-level `in (tuple)` test — this function runs nogil.
                quote_end = value_pos + 1
                while quote_end < end:
                    c = quote_end[0]
                    if c == b' ' or c == b',' or c == b'}' or c == b']' or c == b'\t' or c == b'\n' or c == b'\r':
                        break
                    quote_end += 1
                value_len[0] = quote_end - value_pos
                return value_pos

        pos = key_pos

    return NULL
146+
# C-level declaration of memcmp used by the scanners in this module.
# NOTE(review): this extern block appears lexically *after* find_key_value,
# which already calls memcmp; Cython normally requires cdef names to be
# declared before first use, so this block likely belongs with the other
# libc cimports at the top of the module — confirm against the actual build.
cdef extern from "string.h":
    int memcmp(const void *s1, const void *s2, size_t n)
150+
cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_types, Py_ssize_t sample_size=100):
    """
    Fast JSONL decoder that extracts values using C string operations.

    Parameters:
        buffer: bytes - The JSONL data
        column_names: list - List of column names to extract
        column_types: dict - Dictionary mapping column names to types ('bool', 'int', 'float', 'str', etc.)
        sample_size: int - Retained for interface compatibility; unused here

    Returns:
        tuple: (num_rows, num_cols, dict of column_name -> list of values)

    num_rows is the number of non-empty lines actually decoded.  (The old
    code returned the raw line count, which over-reported when the input
    contained blank lines that contribute no row.)
    """
    # Hoisted: the previous version re-imported json inside the innermost
    # per-value loop.
    import json

    cdef const char* data = PyBytes_AS_STRING(buffer)
    cdef Py_ssize_t data_len = PyBytes_GET_SIZE(buffer)
    cdef const char* line_start
    cdef const char* line_end
    cdef const char* pos = data
    cdef const char* end = data + data_len
    cdef Py_ssize_t line_len
    cdef Py_ssize_t value_start
    cdef Py_ssize_t value_len
    cdef const char* value_ptr
    cdef bytes key_bytes
    cdef const char* key_ptr
    cdef Py_ssize_t key_len
    cdef str col_type
    cdef dict result = {}
    cdef Py_ssize_t num_lines = 0
    cdef Py_ssize_t rows_emitted = 0
    cdef Py_ssize_t i
    cdef Py_ssize_t j
    cdef Py_ssize_t num_cols = len(column_names)
    cdef bytes value_bytes
    cdef list encoded_keys = []
    cdef list resolved_types = []

    # Pre-compute per-column state once: output list, UTF-8 key bytes, and
    # resolved type.  (Previously every key was re-encoded for every row.)
    for col in column_names:
        result[col] = []
        encoded_keys.append(col.encode('utf-8'))
        resolved_types.append(column_types.get(col, 'str'))

    # Count lines first
    cdef const char* newline_pos = pos
    while newline_pos < end:
        newline_pos = <const char*>memchr(newline_pos, b'\n', end - newline_pos)
        if newline_pos == NULL:
            break
        num_lines += 1
        newline_pos += 1

    # If the last line doesn't end with a newline, count it too
    if data_len > 0 and data[data_len - 1] != b'\n':
        num_lines += 1

    # Process each line
    pos = data
    for i in range(num_lines):
        line_start = pos
        line_end = <const char*>memchr(line_start, b'\n', end - line_start)
        if line_end == NULL:
            line_end = end

        line_len = line_end - line_start

        # Skip empty lines — they contribute no row
        if line_len == 0:
            pos = line_end + 1 if line_end < end else end
            continue

        rows_emitted += 1

        # Extract each column from this line
        for j in range(num_cols):
            col = column_names[j]
            key_bytes = encoded_keys[j]
            key_ptr = PyBytes_AS_STRING(key_bytes)
            key_len = PyBytes_GET_SIZE(key_bytes)
            col_type = resolved_types[j]

            value_ptr = find_key_value(line_start, line_len, key_ptr, key_len, &value_start, &value_len)

            if value_ptr == NULL:
                # Key not found on this line
                result[col].append(None)
                continue

            # Parse the raw value bytes according to the declared column type
            if col_type == 'bool':
                if value_len == 4 and memcmp(value_ptr, b"true", 4) == 0:
                    result[col].append(True)
                elif value_len == 5 and memcmp(value_ptr, b"false", 5) == 0:
                    result[col].append(False)
                else:
                    result[col].append(None)

            elif col_type == 'int':
                if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
                    result[col].append(None)
                else:
                    value_bytes = value_ptr[:value_len]
                    try:
                        result[col].append(int(value_bytes))
                    except ValueError:
                        result[col].append(None)

            elif col_type == 'float':
                if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
                    result[col].append(None)
                else:
                    value_bytes = value_ptr[:value_len]
                    try:
                        result[col].append(float(value_bytes))
                    except ValueError:
                        result[col].append(None)

            elif col_type == 'str':
                if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
                    result[col].append(None)
                elif value_ptr[0] == b'"':
                    # value_bytes includes the surrounding quotes
                    value_bytes = value_ptr[:value_len]
                    try:
                        if b'\\' in value_bytes:
                            # Delegate escape handling to the json module: the
                            # old chained .replace() unescape was order-broken
                            # (it turned the two characters "\" + "n" into a
                            # newline) and ignored \uXXXX escapes entirely.
                            result[col].append(json.loads(value_bytes))
                        else:
                            # Fast path: no escapes, strip quotes and decode
                            result[col].append(value_bytes[1:value_len - 1].decode('utf-8'))
                    except (ValueError, UnicodeDecodeError):
                        result[col].append(None)
                else:
                    result[col].append(None)

            else:
                # Other types (list, dict, unknown) fall back to full JSON
                # parsing of the extracted token
                if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0:
                    result[col].append(None)
                else:
                    value_bytes = value_ptr[:value_len]
                    try:
                        parsed = json.loads(value_bytes.decode('utf-8'))
                        if isinstance(parsed, dict):
                            # Dicts are re-serialised to a canonical JSON string
                            result[col].append(json.dumps(parsed, ensure_ascii=False))
                        else:
                            result[col].append(parsed)
                    except (json.JSONDecodeError, UnicodeDecodeError):
                        result[col].append(None)

        # Move to next line
        pos = line_end + 1 if line_end < end else end

    return (rows_emitted, num_cols, result)

opteryx/utils/file_decoders.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def jsonl_decoder(
436436
selection: Optional[list] = None,
437437
just_schema: bool = False,
438438
just_statistics: bool = False,
439+
use_fast_decoder: bool = True,
439440
**kwargs,
440441
) -> Tuple[int, int, pyarrow.Table]:
441442
if just_statistics:
@@ -456,6 +457,85 @@ def jsonl_decoder(
456457
table = pyarrow.Table.from_arrays([[num_rows]], names=["$COUNT(*)"])
457458
return (num_rows, 0, 0, table)
458459

460+
# Try fast Cython decoder for large files with no selection filters
461+
if use_fast_decoder and not just_schema and not selection and len(buffer) > 10000:
462+
try:
463+
from opteryx.compiled.structures import jsonl_decoder as cython_decoder
464+
465+
# Sample first 100 lines to infer schema
466+
parser = simdjson.Parser()
467+
sample_size = min(100, buffer.count(b"\n"))
468+
sample_records = []
469+
keys_union = set()
470+
471+
start = 0
472+
for _ in range(sample_size):
473+
newline = buffer.find(b"\n", start)
474+
if newline == -1:
475+
break
476+
line = buffer[start:newline]
477+
start = newline + 1
478+
if line:
479+
try:
480+
record = parser.parse(line)
481+
row = record.as_dict()
482+
sample_records.append(row)
483+
keys_union.update(row.keys())
484+
except Exception:
485+
continue
486+
487+
if sample_records:
488+
# Infer column types from sample
489+
column_types = {}
490+
columns_to_extract = list(keys_union)
491+
492+
if projection:
493+
# If projection specified, only extract those columns
494+
columns_to_extract = [c.value for c in projection if c.value in keys_union]
495+
496+
for key in columns_to_extract:
497+
for record in sample_records:
498+
if key in record and record[key] is not None:
499+
val = record[key]
500+
if isinstance(val, bool):
501+
column_types[key] = 'bool'
502+
elif isinstance(val, int):
503+
column_types[key] = 'int'
504+
elif isinstance(val, float):
505+
column_types[key] = 'float'
506+
elif isinstance(val, str):
507+
column_types[key] = 'str'
508+
elif isinstance(val, list):
509+
column_types[key] = 'list'
510+
elif isinstance(val, dict):
511+
column_types[key] = 'dict'
512+
break
513+
if key not in column_types:
514+
column_types[key] = 'str' # Default to string
515+
516+
# Use Cython decoder
517+
num_rows, num_cols, column_data = cython_decoder.fast_jsonl_decode_columnar(
518+
buffer, columns_to_extract, column_types, sample_size
519+
)
520+
521+
# Convert to PyArrow table
522+
arrays = []
523+
names = []
524+
for key in sorted(columns_to_extract):
525+
arrays.append(pyarrow.array(column_data[key]))
526+
names.append(key)
527+
528+
if arrays:
529+
table = pyarrow.Table.from_arrays(arrays, names=names)
530+
if projection:
531+
table = post_read_projector(table, projection)
532+
return num_rows, num_cols, 0, table
533+
534+
except (ImportError, Exception) as e:
535+
# Fall back to standard decoder if Cython version fails
536+
import warnings
537+
warnings.warn(f"Fast JSONL decoder failed, using standard decoder: {e}")
538+
459539
parser = simdjson.Parser()
460540

461541
# preallocate and reuse dicts

0 commit comments

Comments
 (0)