Skip to content

Commit a8edcf7

Browse files
committed
fix: typing of deserialized enums.
- exhaustiveness checks
1 parent dc42e0f commit a8edcf7

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

dissect/target/helpers/sunrpc/serializer.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import io
44
from abc import ABC, abstractmethod
5-
from enum import Enum
6-
from typing import Generic, TypeVar
5+
from enum import IntEnum
6+
from typing import Generic, Type, TypeVar, assert_never
77

88
from dissect.target.helpers.sunrpc import sunrpc
99

@@ -13,11 +13,11 @@
1313
Verifier = TypeVar("Verifier")
1414
Serializable = TypeVar("Serializable")
1515
AuthProtocol = TypeVar("AuthProtocol")
16-
EnumType = TypeVar("EN", bound=Enum)
16+
EnumType = TypeVar("EN", bound=IntEnum)
1717
ElementType = TypeVar("ET")
1818

1919

20-
class MessageType(Enum):
20+
class MessageType(IntEnum):
2121
CALL = 0
2222
REPLY = 1
2323

@@ -120,7 +120,7 @@ def _read_int32(self, payload: io.BytesIO) -> int:
120120
def _read_uint64(self, payload: io.BytesIO) -> int:
121121
return int.from_bytes(payload.read(8), byteorder="big", signed=False)
122122

123-
def _read_enum(self, payload: io.BytesIO, enum: EnumType) -> EnumType:
123+
def _read_enum(self, payload: io.BytesIO, enum: Type[EnumType]) -> EnumType:
124124
value = self._read_int32(payload)
125125
return enum(value)
126126

@@ -141,12 +141,12 @@ def _read_optional(self, payload: io.BytesIO, deserializer: Deserializer[Element
141141
return deserializer.deserialize(payload)
142142

143143

144-
class ReplyStat(Enum):
144+
class ReplyStat(IntEnum):
145145
MSG_ACCEPTED = 0
146146
MSG_DENIED = 1
147147

148148

149-
class AuthFlavor(Enum):
149+
class AuthFlavor(IntEnum):
150150
AUTH_NULL = 0
151151
AUTH_UNIX = 1
152152
AUTH_SHORT = 2
@@ -253,6 +253,8 @@ def deserialize(
253253
reply = self._read_accepted_reply(payload)
254254
elif reply_stat == ReplyStat.MSG_DENIED:
255255
reply = self._read_rejected_reply(payload)
256+
else:
257+
assert_never(reply_stat)
256258

257259
return sunrpc.Message(xid, reply)
258260

@@ -287,6 +289,8 @@ def _read_rejected_reply(self, payload: io.BytesIO) -> sunrpc.RejectedReply:
287289
elif reject_stat == sunrpc.RejectStat.AUTH_ERROR:
288290
auth_stat = self._read_enum(payload, sunrpc.AuthStat)
289291
return sunrpc.RejectedReply(reject_stat, auth_stat)
292+
else:
293+
assert_never(reject_stat)
290294

291295
def _read_mismatch(self, payload: io.BytesIO) -> sunrpc.Mismatch:
292296
low = self._read_uint32(payload)

0 commit comments

Comments
 (0)