Skip to content

Commit 5142dc8

Browse files
niftyneirustyrussell
authored andcommitted
pyln-proto: write out length of arrays of subtypes to wire
We weren't writing out the length of a nested subtype's dynamicarraylenght, now we do. The trick is to iterate through the fields on a subtype (since the length field is added separately) and to also iterate down through the otherfield values as we 'descend'
1 parent 6db6ba6 commit 5142dc8

File tree

3 files changed

+45
-8
lines changed

3 files changed

+45
-8
lines changed

contrib/pyln-proto/pyln/proto/message/array_types.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,16 @@ def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Union[str, List[Any]
4848

4949
return [self.elemtype.val_to_py(i, otherfields) for i in v]
5050

51-
def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
52-
for i in v:
53-
self.elemtype.write(io_out, i, otherfields)
51+
def write(self, io_out: BufferedIOBase, vals: List[Any], otherfields: Dict[str, Any]) -> None:
52+
name = self.name.split('.')[1]
53+
if otherfields and name in otherfields:
54+
otherfields = otherfields[name]
55+
for i, val in enumerate(vals):
56+
if isinstance(otherfields, list) and len(otherfields) > i:
57+
fields = otherfields[i]
58+
else:
59+
fields = otherfields
60+
self.elemtype.write(io_out, val, fields)
5461

5562
def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]:
5663
"""arraysize None means take rest of io entirely and exactly"""
@@ -179,7 +186,7 @@ def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[st
179186
if mylen != len(otherfields[lens.name]):
180187
return [fieldname]
181188
# Field might be missing!
182-
if lens.name in otherfields:
189+
if otherfields and lens.name in otherfields:
183190
mylen = len(otherfields[lens.name])
184191
return []
185192

contrib/pyln-proto/pyln/proto/message/message.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,17 @@ def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[st
297297

298298
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
299299
self._raise_if_badvals(v)
300-
for fname, val in v.items():
301-
field = self.find_field(fname)
302-
assert field
303-
field.fieldtype.write(io_out, val, otherfields)
300+
for f in self.fields:
301+
if f.name in v:
302+
val = v[f.name]
303+
else:
304+
if f.option is not None:
305+
raise ValueError("Missing field {} {}".format(f.name, otherfields))
306+
val = None
307+
308+
if self.name in otherfields:
309+
otherfields = otherfields[self.name]
310+
f.fieldtype.write(io_out, val, otherfields)
304311

305312
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
306313
vals = {}

contrib/pyln-proto/tests/test_message.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,29 @@ def test_subtype():
9090
assert m.missing_fields()
9191

9292

93+
def test_subtype_array():
94+
ns = MessageNamespace()
95+
ns.load_csv(['msgtype,tx_signatures,1',
96+
'msgdata,tx_signatures,num_witnesses,u16,',
97+
'msgdata,tx_signatures,witness_stack,witness_stack,num_witnesses',
98+
'subtype,witness_stack',
99+
'subtypedata,witness_stack,num_input_witness,u16,',
100+
'subtypedata,witness_stack,witness_element,witness_element,num_input_witness',
101+
'subtype,witness_element',
102+
'subtypedata,witness_element,len,u16,',
103+
'subtypedata,witness_element,witness,byte,len'])
104+
105+
for test in [["tx_signatures witness_stack="
106+
"[{witness_element=[{witness=3045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01},{witness=02d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b}]}]",
107+
bytes.fromhex('00010001000200483045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01002102d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b')]]:
108+
m = Message.from_str(ns, test[0])
109+
assert m.to_str() == test[0]
110+
buf = io.BytesIO()
111+
m.write(buf)
112+
assert buf.getvalue().hex() == test[1].hex()
113+
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
114+
115+
93116
def test_tlv():
94117
ns = MessageNamespace()
95118
ns.load_csv(['msgtype,test1,1',

0 commit comments

Comments
 (0)