Skip to content

Commit ec12d1c

Browse files
committed
[feat] support struct-of-array (soa) in spec draft-4
1 parent 67f43fb commit ec12d1c

File tree

11 files changed

+2647
-142
lines changed

11 files changed

+2647
-142
lines changed

bjdata/decoder.py

Lines changed: 151 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616

17-
"""BJData (Draft 2) and UBJSON encoder"""
17+
"""BJData (Draft 4) and UBJSON decoder with SOA support"""
1818

1919
from io import BytesIO
2020
from struct import Struct, pack, error as StructError
@@ -55,6 +55,8 @@
5555
dtype as npdtype,
5656
frombuffer as buffer2numpy,
5757
half as halfprec,
58+
zeros as npzeros,
59+
empty as npempty,
5860
)
5961
from array import array as typedarray
6062

@@ -164,6 +166,24 @@
164166
TYPE_CHAR: 1,
165167
}
166168

169+
# Lookup table used by the SOA decoder: BJData type marker -> numpy dtype
# code for the matching structured-array field.  The codes carry no
# byte-order prefix.
__NUMPY_DTYPE_MAP = {
    # single-byte and multi-byte integers
    TYPE_BYTE: "u1",
    TYPE_INT8: "i1",
    TYPE_UINT8: "u1",
    TYPE_INT16: "i2",
    TYPE_UINT16: "u2",
    TYPE_INT32: "i4",
    TYPE_UINT32: "u4",
    TYPE_INT64: "i8",
    TYPE_UINT64: "u8",
    # floating point
    TYPE_FLOAT16: "f2",
    TYPE_FLOAT32: "f4",
    TYPE_FLOAT64: "f8",
    # both boolean markers describe the same boolean column type
    TYPE_BOOL_TRUE: "?",
    TYPE_BOOL_FALSE: "?",
}
186+
167187

168188
class DecoderException(ValueError):
169189
"""Raised when decoding of a UBJSON stream fails."""
@@ -358,6 +378,106 @@ def prodlist(mylist):
358378
return result
359379

360380

381+
def __decode_soa_schema(fp_read, intern_object_keys, le):
    """Decode an SOA schema object: {field1:type1, field2:type2, ...}

    Reading starts right after the opening object marker and stops once the
    object-end marker has been consumed.  No-op markers between entries are
    skipped.

    Args:
        fp_read (callable): Reads the requested number of bytes from the
                            input.
        intern_object_keys (bool): Forwarded to the object key decoder.
        le (1 or 0): Endianness flag, forwarded to the object key decoder.

    Returns:
        List of (field_name, type_marker) tuples in stream order.

    Raises:
        DecoderException: If a field type is not a fixed-length/boolean type,
                          has no numpy dtype mapping, or if a field name
                          occurs more than once.
    """
    schema = []
    seen_names = set()
    marker = fp_read(1)

    while marker != OBJECT_END:
        if marker == TYPE_NOOP:
            marker = fp_read(1)
            continue

        # Each entry is an object key followed by a single type marker
        field_name = __decode_object_key(fp_read, marker, intern_object_keys, le)
        type_marker = fp_read(1)

        if type_marker not in __TYPES_FIXLEN and type_marker not in (
            TYPE_BOOL_TRUE,
            TYPE_BOOL_FALSE,
        ):
            raise DecoderException("SOA schema only supports fixed-length types")

        # The payload decoder looks every field up in __NUMPY_DTYPE_MAP, so a
        # marker without an entry there has to be rejected here instead of
        # surfacing later as a bare KeyError.
        if type_marker not in __NUMPY_DTYPE_MAP:
            raise DecoderException("SOA schema field type has no array mapping")

        # numpy structured dtypes refuse repeated field names with a plain
        # ValueError; report the malformed schema as a decoding error.
        if field_name in seen_names:
            raise DecoderException("Duplicate field name in SOA schema")
        seen_names.add(field_name)

        schema.append((field_name, type_marker))
        marker = fp_read(1)

    return schema
406+
407+
408+
def __decode_soa(fp_read, schema, is_row_major, intern_object_keys, le):
    """Decode an SOA payload into a numpy structured array.

    The payload starts with the count marker followed by either a single
    non-negative integer (1-D) or an array of integer dimensions (N-D).  The
    field values follow, either interleaved record by record (row-major) or
    as one contiguous run per field (column-major).  Boolean fields occupy
    one byte per value; a byte equal to TYPE_BOOL_TRUE decodes to True and
    anything else to False.

    Args:
        fp_read (callable): Reads the requested number of bytes from the
                            input.
        schema (list): (field_name, type_marker) tuples describing a record.
        is_row_major (bool): True if records are interleaved, False if the
                             payload is stored field by field.
        intern_object_keys (bool): Unused here; kept for a signature parallel
                                   to the schema decoder.
        le (1 or 0): Non-zero if multi-byte values in the stream are
                     little-endian, zero for big-endian.

    Returns:
        numpy structured array with one field per schema entry (native byte
        order), reshaped to the decoded dimensions when more than one
        dimension was given.

    Raises:
        DecoderException: If the count marker is missing, a dimension is not
                          an integer, or a field type has no numpy mapping.
    """
    if fp_read(1) != CONTAINER_COUNT:
        raise DecoderException("Expected # after SOA schema")

    marker = fp_read(1)
    if marker == ARRAY_START:
        # ND dimensions: integers up to the closing array marker.  Anything
        # other than an integer (or a no-op) is a malformed stream and must
        # not be skipped silently.
        dims = []
        marker = fp_read(1)
        while marker != ARRAY_END:
            if marker != TYPE_NOOP:
                if marker not in __TYPES_INT:
                    raise DecoderException("SOA dimensions must be integers")
                # Same helper as the scalar count below, so both forms get
                # identical validation of the decoded value.
                dims.append(__decode_int_non_negative(fp_read, marker, le))
            marker = fp_read(1)
        count = prodlist(dims)
    else:
        # Scalar count
        count = __decode_int_non_negative(fp_read, marker, le)
        dims = [count]

    bool_markers = (TYPE_BOOL_TRUE, TYPE_BOOL_FALSE)
    # Stream byte order; without an explicit prefix numpy would assume the
    # host's native order and misread streams of the other endianness.
    byte_order = "<" if le else ">"
    true_code = ord(TYPE_BOOL_TRUE)

    # native_fields describes the returned array, wire_fields the bytes as
    # they appear in the stream (booleans are kept as raw bytes until they
    # have been compared against the "true" marker).
    native_fields = []
    wire_fields = []
    for field_name, type_marker in schema:
        if type_marker in bool_markers:
            native_fields.append((field_name, "?"))
            wire_fields.append((field_name, "u1"))
            continue
        if type_marker not in __NUMPY_DTYPE_MAP:
            raise DecoderException("Unsupported SOA field type")
        code = __NUMPY_DTYPE_MAP[type_marker]
        native_fields.append((field_name, code))
        wire_fields.append((field_name, byte_order + code))

    result = npempty(count, dtype=npdtype(native_fields))

    # Nothing to read for an empty array or an empty schema
    if count and schema:
        if is_row_major:
            # Row-major: interleaved records have exactly the layout of a
            # packed structured dtype, so one read covers the whole payload.
            wire_dtype = npdtype(wire_fields)
            records = buffer2numpy(
                fp_read(count * wire_dtype.itemsize), dtype=wire_dtype
            )
            columns = [records[field_name] for field_name, _ in wire_fields]
        else:
            # Column-major: all values of field1, then field2, etc.
            columns = []
            for _, wire_code in wire_fields:
                column_dtype = npdtype(wire_code)
                columns.append(
                    buffer2numpy(
                        fp_read(count * column_dtype.itemsize), dtype=column_dtype
                    )
                )

        # Assigning into the native-order result converts byte order where
        # needed and leaves the caller with a writable array.
        for (field_name, type_marker), column in zip(schema, columns):
            if type_marker in bool_markers:
                result[field_name] = column == true_code
            else:
                result[field_name] = column

    # Reshape if ND
    if len(dims) > 1:
        result = result.reshape(dims)

    return result
479+
480+
361481
def __get_container_params(
362482
fp_read,
363483
in_mapping,
@@ -372,6 +492,13 @@ def __get_container_params(
372492
dims = []
373493
if marker == CONTAINER_TYPE:
374494
marker = fp_read(1)
495+
496+
# Check for SOA: ${ indicates schema object
497+
if marker == OBJECT_START:
498+
# This is SOA format - decode schema
499+
schema = __decode_soa_schema(fp_read, intern_object_keys, islittle)
500+
return marker, True, -1, schema, [], True # -1 count signals SOA
501+
375502
if marker not in __TYPES:
376503
raise DecoderException("Invalid container type")
377504
type_ = marker
@@ -413,7 +540,7 @@ def __get_container_params(
413540
counting = False
414541
else:
415542
raise DecoderException("Container type without count")
416-
return marker, counting, count, type_, dims
543+
return marker, counting, count, type_, dims, False
417544

418545

419546
def __decode_object(
@@ -425,7 +552,7 @@ def __decode_object(
425552
intern_object_keys,
426553
islittle,
427554
):
428-
marker, counting, count, type_, dims = __get_container_params(
555+
result = __get_container_params(
429556
fp_read,
430557
True,
431558
no_bytes,
@@ -435,6 +562,13 @@ def __decode_object(
435562
intern_object_keys,
436563
islittle,
437564
)
565+
566+
# Check if this is SOA format
567+
if len(result) == 6 and result[5]: # is_soa flag
568+
schema = result[3]
569+
return __decode_soa(fp_read, schema, False, intern_object_keys, islittle)
570+
571+
marker, counting, count, type_, dims, _ = result
438572
has_pairs_hook = object_pairs_hook is not None
439573
obj = [] if has_pairs_hook else {}
440574

@@ -524,7 +658,7 @@ def __decode_array(
524658
intern_object_keys,
525659
islittle,
526660
):
527-
marker, counting, count, type_, dims = __get_container_params(
661+
result = __get_container_params(
528662
fp_read,
529663
False,
530664
no_bytes,
@@ -535,6 +669,13 @@ def __decode_array(
535669
islittle,
536670
)
537671

672+
# Check if this is SOA format (row-major)
673+
if len(result) == 6 and result[5]: # is_soa flag
674+
schema = result[3]
675+
return __decode_soa(fp_read, schema, True, intern_object_keys, islittle)
676+
677+
marker, counting, count, type_, dims, _ = result
678+
538679
# special case - no data (None or bool)
539680
if type_ in __TYPES_NO_DATA:
540681
return [__METHOD_MAP[type_](fp_read, type_, islittle)] * count
@@ -566,7 +707,7 @@ def __decode_array(
566707
container = buffer2numpy(container, dtype=npdtype(__DTYPE_MAP[type_]))
567708
return container
568709

569-
container = []
710+
container = list()
570711
while count > 0 and (counting or marker != ARRAY_END):
571712
if marker == TYPE_NOOP:
572713
marker = fp_read(1)
@@ -645,7 +786,7 @@ def load(
645786
any other array (i.e. result in a list).
646787
uint8_bytes (bool): If set, typed UBJSON arrays (uint8) will be
647788
converted to a bytes instance instead of being
648-
treated as an array (for UBJSON & BJData Draft 2).
789+
treated as an array (for UBJSON & BJData Draft 4).
649790
Ignored if no_bytes is set.
650791
object_hook (callable): Called with the result of any object literal
651792
decoded (instead of dict).
@@ -659,7 +800,7 @@ def load(
659800
in Python2 (since interning does not apply
660801
to unicode) and will be ignored.
661802
islittle (1 or 0): default is 1 for little-endian for all numerics (for
662-
BJData Draft 2), change to 0 to use big-endian
803+
BJData Draft 4), change to 0 to use big-endian
663804
(for UBJSON & BJData Draft 1)
664805
665806
Returns:
@@ -697,6 +838,9 @@ def load(
697838
+----------------------------------+---------------+
698839
| null | None |
699840
+----------------------------------+---------------+
841+
842+
SOA (Structure of Arrays) format is automatically detected and decoded
843+
to numpy structured arrays (record arrays).
700844
"""
701845
if object_pairs_hook is None and object_hook is None:
702846
object_hook = __object_hook_noop

0 commit comments

Comments
 (0)