Skip to content

Commit 0e5d5bd

Browse files
fix linter type errors
1 parent 421a274 commit 0e5d5bd

File tree

4 files changed

+38
-16
lines changed

4 files changed

+38
-16
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{
22
"python.testing.unittestEnabled": true,
3-
"python.testing.pytestEnabled": false
3+
"python.testing.pytestEnabled": false,
4+
"python-envs.defaultEnvManager": "ms-python.python:system"
45
}

src/atmst/mst/node.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from itertools import takewhile
66
from dataclasses import dataclass
77

8-
from typing import TYPE_CHECKING, Tuple, Optional
8+
from typing import TYPE_CHECKING, Tuple, Optional, List
99
if TYPE_CHECKING: # Self doesn't exist <3.11
1010
from typing import Self
1111

@@ -24,9 +24,9 @@ class MSTNode:
2424
If a method is implemented in this class, it's because it's a function/property
2525
of a single node, as opposed to a whole tree
2626
"""
27-
keys: Tuple[str] # collection/rkey
28-
vals: Tuple[CID] # record CIDs
29-
subtrees: Tuple[Optional[CID]] # a None value represents an empty subtree
27+
keys: Tuple[str, ...] # collection/rkey
28+
vals: Tuple[CID, ...] # record CIDs
29+
subtrees: Tuple[Optional[CID], ...] # a None value represents an empty subtree
3030

3131

3232
# NB: __init__ is auto-generated by dataclass decorator
@@ -86,17 +86,31 @@ def serialised(self) -> bytes:
8686
@classmethod
8787
def deserialise(cls, data: bytes) -> "Self":
8888
cbor = decode_dag_cbor(data)
89+
if not isinstance(cbor, dict):
90+
raise ValueError("malformed MST node")
8991
if len(cbor) != 2: # e, l
9092
raise ValueError("malformed MST node")
91-
subtrees = [cbor["l"]]
92-
keys = []
93-
vals = []
93+
l = cbor["l"]
94+
if not isinstance(l, (CID, None.__class__)):
95+
raise ValueError("malformed MST node")
96+
subtrees: List[CID | None] = [l]
97+
keys: List[str] = []
98+
vals: List[CID] = []
9499
prev_key = b""
95-
for e in cbor["e"]: # TODO: make extra sure that these checks are watertight wrt non-canonical representations
100+
es = cbor["e"]
101+
if not isinstance(es, list):
102+
raise ValueError("malformed MST node")
103+
for e in es: # TODO: make extra sure that these checks are watertight wrt non-canonical representations
104+
if not isinstance(e, dict):
105+
raise ValueError("malformed MST node")
96106
if len(e) != 4: # k, p, t, v
97107
raise ValueError("malformed MST node")
98-
prefix_len: int = e["p"]
99-
suffix: bytes = e["k"]
108+
prefix_len = e["p"]
109+
if not isinstance(prefix_len, int):
110+
raise ValueError("malformed MST node")
111+
suffix = e["k"]
112+
if not isinstance(suffix, bytes):
113+
raise ValueError("malformed MST node")
100114
if prefix_len > len(prev_key):
101115
raise ValueError("invalid MST key prefix len")
102116
if prev_key[prefix_len:prefix_len+1] == suffix[:1]:
@@ -105,8 +119,14 @@ def deserialise(cls, data: bytes) -> "Self":
105119
if this_key <= prev_key:
106120
raise ValueError("invalid MST key sort order")
107121
keys.append(this_key.decode())
108-
vals.append(e["v"])
109-
subtrees.append(e["t"])
122+
v = e["v"]
123+
if not isinstance(v, CID):
124+
raise ValueError("invalid MST key sort order")
125+
vals.append(v)
126+
t = e["t"]
127+
if not isinstance(t, CID):
128+
raise ValueError("invalid MST key sort order")
129+
subtrees.append(t)
110130
prev_key = this_key
111131

112132
return cls(

src/atmst/mst/node_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class NodeStore:
1414
for loading and storing MSTNodes
1515
"""
1616
bs: BlockStore
17-
cache: Dict[Optional[CID], MSTNode]
17+
cache: LRU[Optional[CID], MSTNode]
1818

1919
def __init__(self, bs: BlockStore) -> None:
2020
self.bs = bs

src/atmst/mst/node_walker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self,
5050
self.ns = ns
5151
self.trusted = trusted
5252
node = MSTNode.empty_root() if root_cid is None else self.ns.get_node(root_cid)
53-
self.root_height = node.maybe_height if root_height is None else root_height
53+
self.root_height = node.definitely_height() if root_height is None else root_height
5454
if self.root_height is None:
5555
raise ValueError("indeterminate node height - pass it in if you know it")
5656
self.stack = [self.StackFrame(
@@ -61,7 +61,7 @@ def __init__(self,
6161
)]
6262

6363
def subtree_walker(self) -> "Self":
64-
return NodeWalker(
64+
return self.__class__(
6565
self.ns,
6666
self.subtree,
6767
self.lpath,
@@ -147,6 +147,7 @@ def next_kv(self) -> Tuple[str, CID]:
147147
while self.subtree: # recurse down every subtree
148148
self.down()
149149
self.right_or_up()
150+
assert self.lval is not None
150151
return self.lpath, self.lval # the kv pair we just jumped over
151152

152153
# iterate over every k/v pair in key-sorted order

0 commit comments

Comments
 (0)