|
| 1 | +# cython: language_level=3 |
| 2 | +# cython: nonecheck=False |
| 3 | +# cython: cdivision=True |
| 4 | +# cython: initializedcheck=False |
| 5 | +# cython: infer_types=True |
| 6 | +# cython: wraparound=False |
| 7 | +# cython: boundscheck=False |
| 8 | + |
| 9 | +""" |
| 10 | +Fast JSONL decoder using Cython for performance-critical operations. |
| 11 | +
|
| 12 | +This decoder uses native C string operations instead of regex for better performance. |
| 13 | +""" |
| 14 | + |
| 15 | +import numpy |
| 16 | +cimport numpy |
| 17 | +numpy.import_array() |
| 18 | + |
| 19 | +from libc.string cimport memchr, strlen, strstr |
| 20 | +from libc.stdlib cimport strtod, strtol, atoi |
| 21 | +from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_GET_SIZE |
| 22 | +from libc.stdint cimport int64_t |
| 23 | + |
| 24 | +import pyarrow |
| 25 | +from opteryx.third_party.tktech import csimdjson as simdjson |
| 26 | + |
| 27 | + |
| 28 | +cdef inline const char* find_key_value(const char* line, Py_ssize_t line_len, const char* key, Py_ssize_t key_len, Py_ssize_t* value_start, Py_ssize_t* value_len) nogil: |
| 29 | + """ |
| 30 | + Find the value for a given key in a JSON line. |
| 31 | + |
| 32 | + Returns pointer to value start, or NULL if not found. |
| 33 | + Updates value_start and value_len with the position and length. |
| 34 | + """ |
| 35 | + cdef const char* pos = line |
| 36 | + cdef const char* end = line + line_len |
| 37 | + cdef const char* key_pos |
| 38 | + cdef const char* value_pos |
| 39 | + cdef const char* quote_start |
| 40 | + cdef const char* quote_end |
| 41 | + cdef char first_char |
| 42 | + cdef int brace_count |
| 43 | + cdef int bracket_count |
| 44 | + |
| 45 | + # Search for the key pattern: "key": |
| 46 | + while pos < end: |
| 47 | + # Find opening quote of a key |
| 48 | + key_pos = <const char*>memchr(pos, b'"', end - pos) |
| 49 | + if key_pos == NULL: |
| 50 | + return NULL |
| 51 | + |
| 52 | + key_pos += 1 # Move past the opening quote |
| 53 | + |
| 54 | + # Check if this matches our key |
| 55 | + if (end - key_pos >= key_len and |
| 56 | + memcmp(key_pos, key, key_len) == 0 and |
| 57 | + key_pos[key_len] == b'"'): |
| 58 | + |
| 59 | + # Found the key, now find the colon |
| 60 | + value_pos = key_pos + key_len + 1 # Skip closing quote |
| 61 | + |
| 62 | + # Skip whitespace and colon |
| 63 | + while value_pos < end and (value_pos[0] == b' ' or value_pos[0] == b'\t' or value_pos[0] == b':'): |
| 64 | + value_pos += 1 |
| 65 | + |
| 66 | + if value_pos >= end: |
| 67 | + return NULL |
| 68 | + |
| 69 | + first_char = value_pos[0] |
| 70 | + value_start[0] = value_pos - line |
| 71 | + |
| 72 | + # Determine value type and find end |
| 73 | + if first_char == b'"': |
| 74 | + # String value - find closing quote, handling escapes |
| 75 | + quote_start = value_pos + 1 |
| 76 | + quote_end = quote_start |
| 77 | + while quote_end < end: |
| 78 | + if quote_end[0] == b'"' and (quote_end == quote_start or quote_end[-1] != b'\\'): |
| 79 | + value_len[0] = (quote_end + 1) - value_pos |
| 80 | + return value_pos |
| 81 | + quote_end += 1 |
| 82 | + return NULL |
| 83 | + |
| 84 | + elif first_char == b'{': |
| 85 | + # Object - count braces |
| 86 | + brace_count = 1 |
| 87 | + quote_end = value_pos + 1 |
| 88 | + while quote_end < end and brace_count > 0: |
| 89 | + if quote_end[0] == b'{': |
| 90 | + brace_count += 1 |
| 91 | + elif quote_end[0] == b'}': |
| 92 | + brace_count -= 1 |
| 93 | + quote_end += 1 |
| 94 | + value_len[0] = quote_end - value_pos |
| 95 | + return value_pos |
| 96 | + |
| 97 | + elif first_char == b'[': |
| 98 | + # Array - count brackets |
| 99 | + bracket_count = 1 |
| 100 | + quote_end = value_pos + 1 |
| 101 | + while quote_end < end and bracket_count > 0: |
| 102 | + if quote_end[0] == b'[': |
| 103 | + bracket_count += 1 |
| 104 | + elif quote_end[0] == b']': |
| 105 | + bracket_count -= 1 |
| 106 | + quote_end += 1 |
| 107 | + value_len[0] = quote_end - value_pos |
| 108 | + return value_pos |
| 109 | + |
| 110 | + elif first_char == b'n': |
| 111 | + # null |
| 112 | + if end - value_pos >= 4 and memcmp(value_pos, b"null", 4) == 0: |
| 113 | + value_len[0] = 4 |
| 114 | + return value_pos |
| 115 | + return NULL |
| 116 | + |
| 117 | + elif first_char == b't': |
| 118 | + # true |
| 119 | + if end - value_pos >= 4 and memcmp(value_pos, b"true", 4) == 0: |
| 120 | + value_len[0] = 4 |
| 121 | + return value_pos |
| 122 | + return NULL |
| 123 | + |
| 124 | + elif first_char == b'f': |
| 125 | + # false |
| 126 | + if end - value_pos >= 5 and memcmp(value_pos, b"false", 5) == 0: |
| 127 | + value_len[0] = 5 |
| 128 | + return value_pos |
| 129 | + return NULL |
| 130 | + |
| 131 | + else: |
| 132 | + # Number - find end (space, comma, brace, bracket) |
| 133 | + quote_end = value_pos + 1 |
| 134 | + while quote_end < end: |
| 135 | + if quote_end[0] in (b' ', b',', b'}', b']', b'\t', b'\n'): |
| 136 | + break |
| 137 | + quote_end += 1 |
| 138 | + value_len[0] = quote_end - value_pos |
| 139 | + return value_pos |
| 140 | + |
| 141 | + pos = key_pos |
| 142 | + |
| 143 | + return NULL |
| 144 | + |
| 145 | + |
| 146 | +cdef extern from "string.h": |
| 147 | + int memcmp(const void *s1, const void *s2, size_t n) |
| 148 | + |
| 149 | + |
| 150 | +cpdef fast_jsonl_decode_columnar(bytes buffer, list column_names, dict column_types, Py_ssize_t sample_size=100): |
| 151 | + """ |
| 152 | + Fast JSONL decoder that extracts values using C string operations. |
| 153 | + |
| 154 | + Parameters: |
| 155 | + buffer: bytes - The JSONL data |
| 156 | + column_names: list - List of column names to extract |
| 157 | + column_types: dict - Dictionary mapping column names to types ('bool', 'int', 'float', 'str', etc.) |
| 158 | + sample_size: int - Number of lines to use for schema inference (not used if column_types provided) |
| 159 | + |
| 160 | + Returns: |
| 161 | + tuple: (num_rows, num_cols, dict of column_name -> list of values) |
| 162 | + """ |
| 163 | + cdef const char* data = PyBytes_AS_STRING(buffer) |
| 164 | + cdef Py_ssize_t data_len = PyBytes_GET_SIZE(buffer) |
| 165 | + cdef const char* line_start |
| 166 | + cdef const char* line_end |
| 167 | + cdef const char* pos = data |
| 168 | + cdef const char* end = data + data_len |
| 169 | + cdef Py_ssize_t line_len |
| 170 | + cdef Py_ssize_t value_start |
| 171 | + cdef Py_ssize_t value_len |
| 172 | + cdef const char* value_ptr |
| 173 | + cdef bytes key_bytes |
| 174 | + cdef const char* key_ptr |
| 175 | + cdef Py_ssize_t key_len |
| 176 | + cdef str col_type |
| 177 | + cdef list column_data = [] |
| 178 | + cdef dict result = {} |
| 179 | + cdef Py_ssize_t num_lines = 0 |
| 180 | + cdef Py_ssize_t i |
| 181 | + cdef char* end_ptr |
| 182 | + cdef bytes value_bytes |
| 183 | + cdef str value_str |
| 184 | + |
| 185 | + # Initialize column data lists |
| 186 | + for col in column_names: |
| 187 | + column_data.append([]) |
| 188 | + result[col] = column_data[-1] |
| 189 | + |
| 190 | + # Count lines first |
| 191 | + cdef const char* newline_pos = pos |
| 192 | + while newline_pos < end: |
| 193 | + newline_pos = <const char*>memchr(newline_pos, b'\n', end - newline_pos) |
| 194 | + if newline_pos == NULL: |
| 195 | + break |
| 196 | + num_lines += 1 |
| 197 | + newline_pos += 1 |
| 198 | + |
| 199 | + # If last line doesn't end with newline, count it |
| 200 | + if data_len > 0 and data[data_len - 1] != b'\n': |
| 201 | + num_lines += 1 |
| 202 | + |
| 203 | + # Process each line |
| 204 | + pos = data |
| 205 | + for i in range(num_lines): |
| 206 | + # Find line end |
| 207 | + line_start = pos |
| 208 | + line_end = <const char*>memchr(line_start, b'\n', end - line_start) |
| 209 | + if line_end == NULL: |
| 210 | + line_end = end |
| 211 | + |
| 212 | + line_len = line_end - line_start |
| 213 | + |
| 214 | + # Skip empty lines |
| 215 | + if line_len == 0: |
| 216 | + pos = line_end + 1 if line_end < end else end |
| 217 | + continue |
| 218 | + |
| 219 | + # Extract each column |
| 220 | + for j, col in enumerate(column_names): |
| 221 | + key_bytes = col.encode('utf-8') |
| 222 | + key_ptr = PyBytes_AS_STRING(key_bytes) |
| 223 | + key_len = PyBytes_GET_SIZE(key_bytes) |
| 224 | + col_type = column_types.get(col, 'str') |
| 225 | + |
| 226 | + value_ptr = find_key_value(line_start, line_len, key_ptr, key_len, &value_start, &value_len) |
| 227 | + |
| 228 | + if value_ptr == NULL: |
| 229 | + # Key not found |
| 230 | + result[col].append(None) |
| 231 | + continue |
| 232 | + |
| 233 | + # Parse value based on type |
| 234 | + if col_type == 'bool': |
| 235 | + if value_len == 4 and memcmp(value_ptr, b"true", 4) == 0: |
| 236 | + result[col].append(True) |
| 237 | + elif value_len == 5 and memcmp(value_ptr, b"false", 5) == 0: |
| 238 | + result[col].append(False) |
| 239 | + else: |
| 240 | + result[col].append(None) |
| 241 | + |
| 242 | + elif col_type == 'int': |
| 243 | + if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0: |
| 244 | + result[col].append(None) |
| 245 | + else: |
| 246 | + # Use strtol for integer parsing |
| 247 | + value_bytes = value_ptr[:value_len] |
| 248 | + try: |
| 249 | + result[col].append(int(value_bytes)) |
| 250 | + except ValueError: |
| 251 | + result[col].append(None) |
| 252 | + |
| 253 | + elif col_type == 'float': |
| 254 | + if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0: |
| 255 | + result[col].append(None) |
| 256 | + else: |
| 257 | + # Use strtod for float parsing |
| 258 | + value_bytes = value_ptr[:value_len] |
| 259 | + try: |
| 260 | + result[col].append(float(value_bytes)) |
| 261 | + except ValueError: |
| 262 | + result[col].append(None) |
| 263 | + |
| 264 | + elif col_type == 'str': |
| 265 | + if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0: |
| 266 | + result[col].append(None) |
| 267 | + elif value_ptr[0] == b'"': |
| 268 | + # String value - extract without quotes |
| 269 | + value_bytes = value_ptr[1:value_len-1] |
| 270 | + try: |
| 271 | + value_str = value_bytes.decode('utf-8') |
| 272 | + # Simple unescape |
| 273 | + value_str = value_str.replace('\\n', '\n').replace('\\t', '\t').replace('\\"', '"').replace('\\\\', '\\') |
| 274 | + result[col].append(value_str) |
| 275 | + except UnicodeDecodeError: |
| 276 | + result[col].append(None) |
| 277 | + else: |
| 278 | + result[col].append(None) |
| 279 | + |
| 280 | + else: |
| 281 | + # For other types (list, dict, null), fall back to Python |
| 282 | + value_bytes = value_ptr[:value_len] |
| 283 | + if value_len == 4 and memcmp(value_ptr, b"null", 4) == 0: |
| 284 | + result[col].append(None) |
| 285 | + else: |
| 286 | + import json |
| 287 | + try: |
| 288 | + parsed = json.loads(value_bytes.decode('utf-8')) |
| 289 | + if isinstance(parsed, dict): |
| 290 | + result[col].append(json.dumps(parsed, ensure_ascii=False)) |
| 291 | + else: |
| 292 | + result[col].append(parsed) |
| 293 | + except (json.JSONDecodeError, UnicodeDecodeError): |
| 294 | + result[col].append(None) |
| 295 | + |
| 296 | + # Move to next line |
| 297 | + pos = line_end + 1 if line_end < end else end |
| 298 | + |
| 299 | + return (num_lines, len(column_names), result) |
0 commit comments