55from itertools import takewhile
66from dataclasses import dataclass
77
8- from typing import TYPE_CHECKING , Tuple , Optional
8+ from typing import TYPE_CHECKING , Tuple , Optional , List
99if TYPE_CHECKING : # Self doesn't exist <3.11
1010 from typing import Self
1111
@@ -24,9 +24,9 @@ class MSTNode:
2424 If a method is implemented in this class, it's because it's a function/property
2525 of a single node, as opposed to a whole tree
2626 """
27- keys : Tuple [str ] # collection/rkey
28- vals : Tuple [CID ] # record CIDs
29- subtrees : Tuple [Optional [CID ]] # a None value represents an empty subtree
27+ keys : Tuple [str , ... ] # collection/rkey
28+ vals : Tuple [CID , ... ] # record CIDs
29+ subtrees : Tuple [Optional [CID ], ... ] # a None value represents an empty subtree
3030
3131
3232 # NB: __init__ is auto-generated by dataclass decorator
@@ -86,17 +86,31 @@ def serialised(self) -> bytes:
8686 @classmethod
8787 def deserialise (cls , data : bytes ) -> "Self" :
8888 cbor = decode_dag_cbor (data )
89+ if not isinstance (cbor , dict ):
90+ raise ValueError ("malformed MST node" )
8991 if len (cbor ) != 2 : # e, l
9092 raise ValueError ("malformed MST node" )
91- subtrees = [cbor ["l" ]]
92- keys = []
93- vals = []
93+ l = cbor ["l" ]
94+ if not isinstance (l , (CID , None .__class__ )):
95+ raise ValueError ("malformed MST node" )
96+ subtrees : List [CID | None ] = [l ]
97+ keys : List [str ] = []
98+ vals : List [CID ] = []
9499 prev_key = b""
95- for e in cbor ["e" ]: # TODO: make extra sure that these checks are watertight wrt non-canonical representations
100+ es = cbor ["e" ]
101+ if not isinstance (es , list ):
102+ raise ValueError ("malformed MST node" )
103+ for e in es : # TODO: make extra sure that these checks are watertight wrt non-canonical representations
104+ if not isinstance (e , dict ):
105+ raise ValueError ("malformed MST node" )
96106 if len (e ) != 4 : # k, p, t, v
97107 raise ValueError ("malformed MST node" )
98- prefix_len : int = e ["p" ]
99- suffix : bytes = e ["k" ]
108+ prefix_len = e ["p" ]
109+ if not isinstance (prefix_len , int ):
110+ raise ValueError ("malformed MST node" )
111+ suffix = e ["k" ]
112+ if not isinstance (suffix , bytes ):
113+ raise ValueError ("malformed MST node" )
100114 if prefix_len > len (prev_key ):
101115 raise ValueError ("invalid MST key prefix len" )
102116 if prev_key [prefix_len :prefix_len + 1 ] == suffix [:1 ]:
@@ -105,8 +119,14 @@ def deserialise(cls, data: bytes) -> "Self":
105119 if this_key <= prev_key :
106120 raise ValueError ("invalid MST key sort order" )
107121 keys .append (this_key .decode ())
108- vals .append (e ["v" ])
109- subtrees .append (e ["t" ])
122+ v = e ["v" ]
123+ if not isinstance (v , CID ):
124+ raise ValueError ("invalid MST key sort order" )
125+ vals .append (v )
126+ t = e ["t" ]
127+ if not isinstance (t , CID ):
128+ raise ValueError ("invalid MST key sort order" )
129+ subtrees .append (t )
110130 prev_key = this_key
111131
112132 return cls (
0 commit comments