Skip to content

Commit 6b6e97c

Browse files
committed
also handle positional instances
1 parent a517997 commit 6b6e97c

File tree

2 files changed

+179
-18
lines changed

2 files changed

+179
-18
lines changed

serialize_py/codegen_result.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
def get_root(_io=None, check=True):
88
if not _io:
99
_io = kaitaistruct.KaitaiStream(io.BytesIO(bytearray(root_size)))
10-
root = kaitaistruct_sqlite3.Sqlite3(_io)
10+
root = kaitaistruct_sqlite3.Sqlite3(_io=_io, _parent=None, _root=None)
1111
# try to fix root._write
1212
# https://github.com/kaitai-io/kaitai_struct/issues/1245
1313
root.pages__to_write = False
14-
root.header = kaitaistruct_sqlite3.Sqlite3.DatabaseHeader(root._io, root, root._root)
14+
root.header = kaitaistruct_sqlite3.Sqlite3.DatabaseHeader(_io=root._io, _parent=root, _root=root._root)
1515
def init_header(header):
1616
header.magic = b'SQLite format 3\x00'
1717
header.page_size_raw = 4096 # 0x1000
@@ -23,9 +23,10 @@ def init_header(header):
2323
header.leaf_payload_fraction = 32 # 0x20
2424
header.file_change_counter = 1
2525
header.num_pages = 2
26-
header.first_freelist_trunk_page = kaitaistruct_sqlite3.Sqlite3.FreelistTrunkPagePointer(root._io, header, header._root)
26+
header.first_freelist_trunk_page = kaitaistruct_sqlite3.Sqlite3.FreelistTrunkPagePointer(_io=root._io, _parent=header, _root=header._root)
2727
def init_first_freelist_trunk_page(first_freelist_trunk_page):
2828
first_freelist_trunk_page.page_number = 0
29+
first_freelist_trunk_page.page = None
2930
init_first_freelist_trunk_page(header.first_freelist_trunk_page)
3031
header.num_freelist_pages = 0
3132
header.schema_cookie = 1
@@ -40,6 +41,33 @@ def init_first_freelist_trunk_page(first_freelist_trunk_page):
4041
header.version_valid_for = 1
4142
header.sqlite_version_number = 3050001 # 0x2e8a11
4243
init_header(root.header)
44+
root.pages = []
45+
root.pages.append(kaitaistruct_sqlite3.Sqlite3.BtreePage(page_number=1, _io=root._io, _parent=root, _root=root._root))
46+
def init_page(page):
47+
page.page_type = kaitaistruct_sqlite3.Sqlite3.BtreePageType.table_leaf_page # 13 = 0xd
48+
page.first_freeblock = 0
49+
page.num_cells = 1
50+
page.ofs_cell_content_area_raw = 4044 # 0xfcc
51+
page.num_frag_free_bytes = 0
52+
page.cell_pointers = []
53+
page.cell_pointers.append(kaitaistruct_sqlite3.Sqlite3.CellPointer(_io=root._io, _parent=page, _root=page._root))
54+
def init_cell_pointer(cell_pointer):
55+
cell_pointer.ofs_content = 4044 # 0xfcc
56+
init_cell_pointer(page.cell_pointers[0])
57+
page.cell_content_area = b'2\x01\x06\x17\x15\x15\x01Itabletesttest\x02CREATE TABLE test (id INTEGER)'
58+
page.reserved_space = None
59+
init_page(root.pages[0])
60+
root.pages.append(kaitaistruct_sqlite3.Sqlite3.BtreePage(page_number=2, _io=root._io, _parent=root, _root=root._root))
61+
def init_page(page):
62+
page.page_type = kaitaistruct_sqlite3.Sqlite3.BtreePageType.table_leaf_page # 13 = 0xd
63+
page.first_freeblock = 0
64+
page.num_cells = 0
65+
page.ofs_cell_content_area_raw = 4096 # 0x1000
66+
page.num_frag_free_bytes = 0
67+
page.cell_pointers = []
68+
page.cell_content_area = b''
69+
page.reserved_space = None
70+
init_page(root.pages[1])
4371
if check:
4472
root._check()
4573
return root

serialize_py/kaitai_serialize_codegen.py

Lines changed: 148 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,24 @@ def get_seq(obj):
133133
continue
134134
return seq
135135

136+
def get_instances(obj):
    """Return the instance names touched by ``obj._fetch_instances``.

    The names are recovered by scanning the source text of the method for
    statements of the form ``_ = self.<name>``.

    Args:
        obj: any object; objects without a ``_fetch_instances`` attribute
            are accepted.

    Returns:
        A list of attribute names in order of first appearance in the
        method source, without duplicates. Empty when ``obj`` has no
        ``_fetch_instances`` attribute.

    Raises:
        OSError, TypeError: propagated from ``inspect.getsourcelines`` when
            the source of the method cannot be retrieved.
    """
    # TODO upstream: this should be simpler
    if not hasattr(obj, "_fetch_instances"):
        return []
    source_lines, _ = inspect.getsourcelines(obj._fetch_instances)
    # each relevant statement looks like: "    _ = self.pages"
    # (the leading "def _fetch_instances(self):" line never matches,
    # so it is simply skipped together with every other non-matching line)
    pattern = re.compile(r"\s+_ = self\.(\w+)")
    hits = (pattern.match(text) for text in source_lines[1:])
    # the same name can be fetched in several branches of the method;
    # dict.fromkeys keeps the first occurrence and preserves source order,
    # so callers do not process one instance twice
    return list(dict.fromkeys(hit[1] for hit in hits if hit))
153+
136154
def parse_enum_map(lines):
137155
enum_map = dict()
138156
line0 = lines.pop(0)
@@ -150,6 +168,8 @@ def parse_enum_map(lines):
150168
return enum_map
151169

152170
def get_local_key(key, global_names):
171+
# # handle array item keys like "some_array[123]"
172+
# key = key.replace("[", "_").replace("]", "_")
153173
num = 1
154174
local_key = key
155175
while local_key in global_names:
@@ -221,6 +241,9 @@ class E # 0 # E
221241
# class FormatVersion(IntEnum):
222242

223243

244+
debug_init_types = False
245+
246+
224247
def codegen(
225248
obj,
226249
out,
@@ -236,6 +259,7 @@ def codegen(
236259
module_map={},
237260
global_names=[],
238261
):
262+
print("codegen obj", obj)
239263
global val # fix print_value
240264
mod = obj.__class__.__module__
241265
# member = obj.__class__.__name__ # DatabaseHeader
@@ -272,7 +296,20 @@ def codegen(
272296
print(f"{ind}def get_{root_name}(_io=None, check=True):", file=out)
273297
print(f"{ind}{ids}if not _io:", file=out)
274298
print(f"{ind}{ids}{ids}_io = kaitaistruct.KaitaiStream(io.BytesIO(bytearray(root_size)))", file=out)
275-
print(f"{ind}{ids}{on} = {mod}.{member}(_io)", file=out)
299+
# TODO also pass parameters to root.__init__
300+
"""
301+
val_params = []
302+
if hasattr(val, "__init__"):
303+
val_init_sig = inspect.signature(val.__init__)
304+
# ...
305+
"""
306+
307+
# print(f"{ind}{ids}{on} = {mod}.{member}(_io=_io)", file=out)
308+
on_parent_root = f"{on_parent}._root" if on_parent else "None"
309+
print(f"{ind}{ids}{on} = {mod}.{member}(_io=_io, _parent={on_parent}, _root={on_parent_root})", file=out)
310+
311+
# print(f"{ind}{ids}assert {on}._root == {on}", file=out) # debug
312+
276313
# TODO remove. this works only for sqlite3.ksy
277314
print(f"{ind}{ids}# try to fix root._write", file=out)
278315
print(f"{ind}{ids}# https://github.com/kaitai-io/kaitai_struct/issues/1245", file=out)
@@ -281,9 +318,30 @@ def codegen(
281318
# else:
282319
# print(f"{ind}{ids}# non-root init", file=out)
283320
# print(f"{ind}{ids}{on} = {mod}.{member}(_io, {on_parent}, {on_parent}._root)", file=out)
284-
for key in get_seq(obj):
321+
# TODO? interleave "seq" and "instance" keys
322+
# TODO rename to seq_key?
323+
# for key in get_seq(obj):
324+
key_stack = get_seq(obj) + get_instances(obj)
325+
while key_stack:
326+
key = key_stack.pop(0)
285327
# print(f"{ind}{ids}# key {key}", file=out)
286-
val = getattr(obj, key)
328+
print("key", key) # debug
329+
val_is_list_item = False
330+
if key.endswith("]"):
331+
# val is a list item
332+
val_is_list_item = True
333+
m = re.fullmatch(r"(\w+)\[(\d+)\]", key)
334+
val_arr_name, val_arr_idx = m.groups()
335+
val_arr_idx = int(val_arr_idx)
336+
val_arr = getattr(obj, val_arr_name)
337+
val = val_arr[val_arr_idx]
338+
else:
339+
# FIXME get_seq also returns items where the "if" condition is false
340+
# val = getattr(obj, key)
341+
try:
342+
val = getattr(obj, key)
343+
except AttributeError:
344+
continue
287345
"""
288346
print("key", repr(key))
289347
print("val", repr(val), dir(val))
@@ -298,15 +356,30 @@ def codegen(
298356

299357
# builtin types: int, bytes, ...
300358
if mod == "builtins":
359+
if debug_init_types:
360+
print(f"{ind}{ids}# builtin type {type(val).__name__}", file=out)
301361
if isinstance(val, int) and val > 10:
302362
print(f"{ind}{ids}{on}.{key} = {val!r} # {hex(val)}", file=out)
303363
continue
304364
if isinstance(val, bytes) and val == len(val) * b"\x00":
305365
# compress null bytes
306366
# TODO partial compression of bytestrings
307-
print(f"{ind}{ids}{on}.{key} = {len(val)} * b'\\x00'", file=out)
367+
if len(val) == 0:
368+
print(f"{ind}{ids}{on}.{key} = b''", file=out)
369+
else:
370+
print(f"{ind}{ids}{on}.{key} = {len(val)} * b'\\x00'", file=out)
371+
continue
372+
if isinstance(val, list):
373+
print(f"{ind}{ids}{on}.{key} = []", file=out)
374+
new_keys = []
375+
for item_idx in range(len(val)):
376+
new_keys.append(f"{key}[{item_idx}]")
377+
# recursion via stack
378+
key_stack = new_keys + key_stack
379+
# TODO
380+
# print(f"{ind}{ids}{on}.{key}.append({xxxxxxx})", file=out)
308381
continue
309-
# bytes, ...
382+
# bytes, str, ...
310383
print(f"{ind}{ids}{on}.{key} = {val!r}", file=out)
311384
continue
312385

@@ -327,6 +400,8 @@ def codegen(
327400
m = re.match(r"\s*class (\w+)\(([A-Z][A-Za-z0-9]*Enum)\):", lines[0].rstrip())
328401
if m:
329402
enum_name, enum_type = m.groups()
403+
if debug_init_types:
404+
print(f"{ind}{ids}# enum type {enum_name}", file=out)
330405
enum_map = enum_map_map.get(enum_name) # read cache
331406
if not enum_map:
332407
enum_map = parse_enum_map(lines)
@@ -346,27 +421,58 @@ def codegen(
346421
print(f"{ind}{ids}{on}.{key} = {mod}.{enum_qualname}.{enum_key} # {val_str}", file=out)
347422
continue
348423

349-
# TODO handle list types
350-
# m = ...
351-
# if m:
352-
# ...
353-
# continue
354-
355424
# user-defined types
425+
if debug_init_types:
426+
print(f"{ind}{ids}# user-defined type {member}", file=out)
356427
# https://doc.kaitai.io/serialization.html#_user_defined_types
357428
# print(f"{ind}{ids}{on}.{key} = root.{member}(root._io, {on}, {on}._root)", file=out) # short
358429
# print(f"{ind}{ids}{on}.{key} = {mod}.{root_cln}.{member}(root._io, {on}, {on}._root)", file=out) # long
359-
print(f"{ind}{ids}{on}.{key} = {mod}.{member}(root._io, {on}, {on}._root)", file=out) # long
430+
# print(f"{ind}{ids}{on}.{key} = {mod}.{member}(root._io, {on}, {on}._root)", file=out) # long
431+
val_params = []
432+
if hasattr(val, "__init__"):
433+
val_init_sig = inspect.signature(val.__init__)
434+
if str(val_init_sig) != "(_io=None, _parent=None, _root=None)":
435+
# print("val_init_sig", repr(val_init_sig))
436+
# val.__init__ has extra args
437+
# example: page_number in "(page_number, _io=None, _parent=None, _root=None)"
438+
for param_name in val_init_sig.parameters.keys():
439+
# print(f"param_name {param_name}")
440+
if param_name in ("_io", "_parent", "_root"):
441+
continue
442+
# FIXME handle user-defined types via recursion
443+
# example:
444+
"""
445+
def get_page_number():
446+
# ...
447+
pages.append(BtreePage(page_number=get_page_number(), _io=root._io, _parent=root, _root=root._root))
448+
"""
449+
param_val = getattr(val, param_name)
450+
val_params.append(f"{param_name}={param_val}")
451+
val_params = "".join(map(lambda arg: arg + ", ", val_params))
452+
if val_is_list_item:
453+
print(f"{ind}{ids}{on}.{val_arr_name}.append({mod}.{member}({val_params}_io=root._io, _parent={on}, _root={on}._root))", file=out) # long
454+
else:
455+
print(f"{ind}{ids}{on}.{key} = {mod}.{member}({val_params}_io=root._io, _parent={on}, _root={on}._root)", file=out) # long
456+
def get_singular_name(plural_name):
    """Derive a singular variable name from a list attribute name.

    Only the first matching suffix is removed, checked in this order:
    ``_list``, ``_array``, ``s``.

    Examples:
        vals -> val
        val_list -> val
        val_array -> val

    Args:
        plural_name: name of the list attribute.

    Returns:
        The name with the suffix removed, or ``plural_name`` unchanged when
        no suffix matches or when removing it would leave an empty string.
    """
    for suffix in ("_list", "_array", "s"):
        if plural_name.endswith(suffix):
            stem = plural_name[: -len(suffix)]
            # a name that consists only of the suffix (for example "s")
            # would become "", which is not usable as an identifier in
            # the generated code, so keep the original name in that case
            return stem or plural_name
    return plural_name
360463
# avoid shadowing global variables
361-
local_key = get_local_key(key, global_names)
464+
if val_is_list_item:
465+
local_key = get_local_key(get_singular_name(val_arr_name), global_names)
466+
else:
467+
local_key = get_local_key(key, global_names)
362468
# print(f"{ind}{ids}if 1:", file=out) # no block scope
363469
# print(f"{ind}{ids}if {local_key} := {on}.{key}:", file=out) # no block scope
364470
# TypeError: 'int' object does not support the context manager protocol
365471
# print(f"{ind}{ids}with {on}.{key} as {local_key}:", file=out) # context # no block scope?
366472
# create block scope
367473
# this is required to avoid name collisions between scopes
368474
# https://stackoverflow.com/a/45210833/10440128
369-
print(f"{ind}{ids}def init_{key}({local_key}):", file=out) # "init_" prefix
475+
print(f"{ind}{ids}def init_{local_key}({local_key}):", file=out) # "init_" prefix
370476
# print(f"{ind}{ids}def {key}_init({local_key}):", file=out) # "_init" suffix
371477
# recursion
372478
codegen(
@@ -382,9 +488,36 @@ def codegen(
382488
module_map,
383489
global_names,
384490
)
385-
print(f"{ind}{ids}init_{key}({on}.{key})", file=out) # "init_" prefix
491+
492+
if val_is_list_item:
493+
print(f"{ind}{ids}init_{local_key}({on}.{val_arr_name}[{val_arr_idx}])", file=out) # "init_" prefix
494+
else:
495+
print(f"{ind}{ids}init_{local_key}({on}.{key})", file=out) # "init_" prefix
496+
386497
# print(f"{ind}{ids}{key}_init({local_key})", file=out) # "_init" suffix
387498

499+
# for instance_key in get_instances(obj):
500+
if 0:
501+
# print(f"{ind}{ids}# instance_key {instance_key}", file=out)
502+
val = getattr(obj, instance_key)
503+
"""
504+
print("instance_key", repr(instance_key))
505+
print("val", repr(val), dir(val))
506+
print_value("val.__class__.__module__")
507+
print_value("val.__class__.__qualname__")
508+
"""
509+
# obj.__class__.__module__ == 'builtins'
510+
# TODO rename to "mod_name"
511+
mod = val.__class__.__module__
512+
# TODO rename to "member_name"
513+
member = val.__class__.__qualname__
514+
515+
print("obj", obj)
516+
print("FIXME instance_key", instance_key, val, mod, member)
517+
# FIXME instance_key page 0 builtins int
518+
# FIXME instance_key page None builtins NoneType
519+
raise 123
520+
388521
# some user-defined types need this
389522
# example: AttributeError: 'VlqBase128Be' object has no attribute 'groups'
390523
# but this breaks other cases...

0 commit comments

Comments
 (0)