-
Notifications
You must be signed in to change notification settings - Fork 47
Open
Description
from @curtischong in #428 (comment):
gemini gave me this code (I didn't foresee the shifting logic)
import numpy as np
import io # Used for the example
def decode_leb128_vectorized(byte_array: np.ndarray, count: int) -> tuple[np.ndarray, int]:
    """Decode ``count`` unsigned LEB128 integers from a byte buffer in one shot.

    In LEB128, every byte of a number except the last has its high bit
    (0x80) set; the low 7 bits of each byte carry the payload,
    least-significant 7-bit group first.

    Args:
        byte_array: A 1D NumPy array of uint8 bytes containing the stream.
        count: The number of integers to decode (non-positive means zero).

    Returns:
        A tuple containing:
        - A 1D NumPy array (uint32) of the decoded integers.
        - The total number of bytes consumed from the input array.

    Raises:
        ValueError: If the stream holds fewer than ``count`` complete numbers.
    """
    # Edge case: nothing requested means nothing consumed. The original code
    # would have indexed end_indices[-1] here and consumed the whole stream.
    if count <= 0:
        return np.empty(0, dtype=np.uint32), 0
    # A number ends on the first byte whose continuation (MSB) bit is clear.
    is_end_byte = (byte_array & 128) == 0
    end_indices = np.where(is_end_byte)[0]
    if len(end_indices) < count:
        raise ValueError(f"Data stream is too short. Expected {count} numbers, but found only {len(end_indices)}.")
    # Trim to only the bytes needed for the required count of numbers.
    last_byte_idx = end_indices[count - 1]
    # int() so the declared `int` return type holds (np.int64 otherwise).
    bytes_consumed = int(last_byte_idx) + 1
    # We only need to work with the relevant slice of the array.
    byte_slice = byte_array[:bytes_consumed]
    end_indices = end_indices[:count]
    # Each number starts at 0 or one past the previous number's end byte.
    start_indices = np.insert(end_indices[:-1] + 1, 0, 0)
    # --- Vectorized Reconstruction ---
    # 1. Get the 7-bit data payloads for each byte, widened to int64 so the
    #    shifts below (up to 28 for a 5-byte uint32 encoding) cannot overflow.
    payloads = (byte_slice & 127).astype(np.int64)
    # 2. Identify which number (or "group") each byte belongs to.
    group_ids = np.zeros(len(byte_slice), dtype=np.int32)
    group_ids[start_indices] = 1
    group_ids = np.cumsum(group_ids) - 1
    # 3. Find the position of each byte within its own number (0, 1, 2, ...).
    byte_pos_in_group = np.arange(len(byte_slice)) - start_indices[group_ids]
    # 4. Calculate the bit shift (0, 7, 14, ...) for each byte.
    shifts = byte_pos_in_group * 7
    # 5. Apply shifts to all payloads at once.
    shifted_payloads = payloads << shifts
    # 6. Sum the shifted payloads at the start of each number's sequence;
    #    reduceat segments the array at each number's first byte.
    decoded_values = np.add.reduceat(shifted_payloads, start_indices)
    return decoded_values.astype(np.uint32), bytes_consumed
# --- Example Usage ---
# Stand-in for the real byte stream object (e.g. a file handle / self._stream).
# Two numbers are pre-encoded: 624485 -> 0xE5 0x8E 0x26 (3 bytes),
# and 127 -> 0x7F (1 byte).
encoded = bytes((0xE5, 0x8E, 0x26, 0x7F))
demo_stream = io.BytesIO(encoded)
# --- New, fast data loading logic ---
# Target shape of the decoded data.
height, width, bins = 1, 2, 1
total_numbers_to_read = height * width * bins
# 1. Pull a block of bytes from the stream (a real file could be read whole
#    if it is not too large) and view it as a uint8 array without copying.
byte_array = np.frombuffer(demo_stream.read(), dtype=np.uint8)
# 2. Decode every number in a single vectorized call.
decoded_1d_array, bytes_consumed = decode_leb128_vectorized(byte_array, total_numbers_to_read)
# 3. Fold the flat result into the desired 3D shape.
data = decoded_1d_array.reshape((height, width, bins))
print("Decoded data:\n", data)
print("\nValues:", data.flatten())
print("Bytes Consumed:", bytes_consumed)
# Expected output:
# Decoded data:
#  [[[624485]
#    [   127]]]
#
# Values: [624485    127]
# Bytes Consumed: 4

At first glance this looks alright, but we need a test to properly verify this. Gemini assumes that we do not know how many bytes the entire bzip2 buffer is, but I suspect that we do know (so we can simplify this).
Originally posted by @curtischong in #428 (comment)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels