|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | + |
| 4 | +""" |
| 5 | +tf2onnx.flexbuffers - Code for parsing flexbuffers |
| 6 | +""" |
| 7 | + |
| 8 | +import struct |
| 9 | + |
| 10 | + |
| 11 | +class FlexbufferParseException(Exception): |
| 12 | + pass |
| 13 | + |
| 14 | + |
| 15 | +def read_int(buffer, offset, bit_size): |
| 16 | + size = 1 << bit_size |
| 17 | + format_char = 'bhiq'[bit_size] |
| 18 | + return struct.unpack('<' + format_char, buffer[offset:offset+size])[0] |
| 19 | + |
| 20 | + |
| 21 | +def read_uint(buffer, offset, bit_size): |
| 22 | + size = 1 << bit_size |
| 23 | + format_char = 'BHIQ'[bit_size] |
| 24 | + return struct.unpack('<' + format_char, buffer[offset:offset+size])[0] |
| 25 | + |
| 26 | + |
| 27 | +def read_float(buffer, offset, bit_size): |
| 28 | + if bit_size == 2: |
| 29 | + return struct.unpack('<f', buffer[offset:offset+4])[0] |
| 30 | + if bit_size == 3: |
| 31 | + return struct.unpack('<d', buffer[offset:offset+8])[0] |
| 32 | + raise FlexbufferParseException("Invalid bit size for flexbuffer float: %d" % bit_size) |
| 33 | + |
| 34 | + |
| 35 | +def read_string(buffer, offset, size): |
| 36 | + return buffer[offset:offset+size].decode('utf-8') |
| 37 | + |
| 38 | + |
| 39 | +def read_indirect(buffer, offset, bit_size): |
| 40 | + return offset - read_uint(buffer, offset, bit_size) |
| 41 | + |
| 42 | + |
| 43 | +def read_bytes(buffer, offset, size): |
| 44 | + return buffer[offset:offset+size] |
| 45 | + |
| 46 | + |
| 47 | +def read_array(buffer, offset, length, bit_size, packed_type): |
| 48 | + byte_size = 1 << bit_size |
| 49 | + arr = [] |
| 50 | + for i in range(length): |
| 51 | + item_offset = offset + (i * byte_size) |
| 52 | + arr.append(read_buffer(buffer, item_offset, bit_size, packed_type)) |
| 53 | + return arr |
| 54 | + |
| 55 | + |
| 56 | +def read_buffer(buffer, offset, parent_bit_size, packed_type): |
| 57 | + """Recursively decode flatbuffer object into python representation""" |
| 58 | + bit_size = packed_type & 3 |
| 59 | + value_type = packed_type >> 2 |
| 60 | + byte_size = 1 << bit_size |
| 61 | + |
| 62 | + if value_type == 0x0: |
| 63 | + return None |
| 64 | + if value_type in [0x1, 0x2, 0x3]: |
| 65 | + read_fn = {0x1: read_int, 0x2: read_uint, 0x3: read_float}[value_type] |
| 66 | + return read_fn(buffer, offset, parent_bit_size) |
| 67 | + if value_type in [0x4, 0x5]: |
| 68 | + str_offset = read_indirect(buffer, offset, parent_bit_size) |
| 69 | + size = 0 |
| 70 | + while read_int(buffer, str_offset + size, 0) != 0: |
| 71 | + size += 1 |
| 72 | + return read_string(buffer, str_offset, size) |
| 73 | + if value_type == 0x5: |
| 74 | + str_offset = read_indirect(buffer, offset, parent_bit_size) |
| 75 | + size_byte_size = 1 << bit_size |
| 76 | + size = read_uint(buffer, str_offset - size_byte_size, bit_size) |
| 77 | + while read_int(buffer, str_offset + size, 0) != 0: |
| 78 | + size_byte_size <<= 1 |
| 79 | + size = read_uint(buffer, str_offset - size_byte_size, bit_size) |
| 80 | + return read_string(buffer, str_offset, size) |
| 81 | + if value_type in [0x6, 0x7, 0x8]: |
| 82 | + read_fn = {0x6: read_int, 0x7: read_uint, 0x8: read_float}[value_type] |
| 83 | + data_offset = read_indirect(buffer, offset, parent_bit_size) |
| 84 | + return read_fn(buffer, data_offset, bit_size) |
| 85 | + if value_type == 0x9: |
| 86 | + length = read_uint(buffer, read_indirect(buffer, offset, parent_bit_size) - byte_size, bit_size) |
| 87 | + keys_offset = read_indirect(buffer, offset, parent_bit_size) - (byte_size * 3) |
| 88 | + keys_vector_offset = read_indirect(buffer, keys_offset, bit_size) |
| 89 | + key_byte_size = read_uint(buffer, keys_offset + byte_size, bit_size) |
| 90 | + key_bit_size = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4}[key_byte_size] |
| 91 | + values_offset = read_indirect(buffer, offset, parent_bit_size) |
| 92 | + packed_types_offset = values_offset + length * byte_size |
| 93 | + obj = {} |
| 94 | + for i in range(length): |
| 95 | + key_offset = keys_vector_offset + i * key_byte_size |
| 96 | + key = read_buffer(buffer, key_offset, key_bit_size, (0x4 << 2) | key_bit_size) |
| 97 | + value_offset = values_offset + i * byte_size |
| 98 | + value_packed_type = read_uint(buffer, packed_types_offset + i, 0) |
| 99 | + value = read_buffer(buffer, value_offset, bit_size, value_packed_type) |
| 100 | + obj[key] = value |
| 101 | + return obj |
| 102 | + if value_type == 0xa: |
| 103 | + length = read_uint(buffer, read_indirect(buffer, offset, parent_bit_size) - byte_size, bit_size) |
| 104 | + arr = [] |
| 105 | + items_offset = read_indirect(buffer, offset, parent_bit_size) |
| 106 | + packed_types_offset = items_offset + (length * byte_size) |
| 107 | + for i in range(length): |
| 108 | + item_offset = items_offset + (i * byte_size) |
| 109 | + packed_type = read_uint(buffer, packed_types_offset + i, 0) |
| 110 | + arr.append(read_buffer(buffer, item_offset, bit_size, packed_type)) |
| 111 | + return arr |
| 112 | + if value_type in [0xb, 0xc, 0xd, 0xe, 0xf, 0x24]: |
| 113 | + length_offset = read_indirect(buffer, offset, parent_bit_size) - byte_size |
| 114 | + length = read_uint(buffer, length_offset, bit_size) |
| 115 | + item_value_type = value_type - 0xb + 0x1 |
| 116 | + packed_type = item_value_type << 2 |
| 117 | + items_offset = read_indirect(buffer, offset, parent_bit_size) |
| 118 | + return read_array(buffer, items_offset, length, bit_size, packed_type) |
| 119 | + if 0x10 <= value_type <= 0x18: |
| 120 | + length = (value_type - 0x10) // 3 + 2 |
| 121 | + value_type = ((value_type - 0x10) % 3) + 1 |
| 122 | + packed_type = value_type << 2 |
| 123 | + items_offset = read_indirect(buffer, offset, parent_bit_size) |
| 124 | + return read_array(buffer, items_offset, length, bit_size, packed_type) |
| 125 | + if value_type == 0x19: |
| 126 | + data_offset = read_indirect(buffer, offset, parent_bit_size) |
| 127 | + size_offset = data_offset - byte_size |
| 128 | + size = read_uint(buffer, size_offset, bit_size) |
| 129 | + return read_bytes(buffer, data_offset, size) |
| 130 | + if value_type == 0x1a: |
| 131 | + return read_uint(buffer, offset, parent_bit_size) > 0 |
| 132 | + raise FlexbufferParseException("Invalid flexbuffer value type %r" % value_type) |
| 133 | + |
| 134 | + |
| 135 | +def read_flexbuffer(buffer): |
| 136 | + byte_size = read_uint(buffer, len(buffer) - 1, 0) |
| 137 | + bit_size = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4}[byte_size] |
| 138 | + packed_type = read_uint(buffer, len(buffer) - 2, 0) |
| 139 | + offset = len(buffer) - 2 - byte_size |
| 140 | + return read_buffer(buffer, offset, bit_size, packed_type) |
0 commit comments