Skip to content

Commit 92ede66

Browse files
committed
fix: correctly wrap rejected reply
1 parent 4cdb8ad commit 92ede66

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

dissect/target/helpers/sunrpc/serializer.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _flavor(self) -> int:
190190
return AuthFlavor.AUTH_NULL.value
191191

192192
def _write_body(self, _: AuthProtocol) -> bytes:
193-
return bytes()
193+
return b""
194194

195195
def _read_body(self, _: io.BytesIO) -> AuthProtocol:
196196
return sunrpc.AuthNull()
@@ -235,30 +235,28 @@ def __init__(
235235
self._verifierSerializer = verifierSerializer
236236

237237
def serialize(self, message: sunrpc.Message[ProcedureParams, ProcedureResults, Credentials, Verifier]) -> bytes:
238-
result = self._write_uint32(message.xid)
239-
240-
if isinstance(message.body, sunrpc.CallBody):
241-
result += self._write_enum(MessageType.CALL)
242-
return result + self._write_call_body(message.body)
238+
if not isinstance(message.body, sunrpc.CallBody):
239+
raise NotImplementedError("Only CALL messages are serializable")
243240

244-
raise NotImplementedError("Only CALL messages are serializable")
241+
result = self._write_uint32(message.xid)
242+
result += self._write_enum(MessageType.CALL)
243+
return result + self._write_call_body(message.body)
245244

246245
def deserialize(
247246
self, payload: io.BytesIO
248247
) -> sunrpc.Message[ProcedureParams, ProcedureResults, Credentials, Verifier]:
249248
xid = self._read_uint32(payload)
250-
messageType = self._read_enum(payload, MessageType)
251-
252-
if messageType == MessageType.REPLY:
253-
replyStat = self._read_enum(payload, ReplyStat)
254-
if replyStat == ReplyStat.MSG_ACCEPTED:
255-
reply = self._read_accepted_reply(payload)
256-
elif replyStat == ReplyStat.MSG_DENIED:
257-
reply = self._read_rejected_reply(payload)
249+
message_type = self._read_enum(payload, MessageType)
250+
if message_type != MessageType.REPLY:
251+
raise NotImplementedError("Only REPLY messages are deserializable")
258252

259-
return sunrpc.Message(xid, reply)
253+
reply_stat = self._read_enum(payload, ReplyStat)
254+
if reply_stat == ReplyStat.MSG_ACCEPTED:
255+
reply = self._read_accepted_reply(payload)
256+
elif reply_stat == ReplyStat.MSG_DENIED:
257+
reply = self._read_rejected_reply(payload)
260258

261-
raise NotImplementedError("Only REPLY messages are deserializable")
259+
return sunrpc.Message(xid, reply)
262260

263261
def _write_call_body(self, call_body: sunrpc.CallBody) -> bytes:
264262
result = self._write_uint32(call_body.rpc_version)
@@ -286,10 +284,11 @@ def _read_accepted_reply(self, payload: io.BytesIO) -> sunrpc.AcceptedReply[Proc
286284
def _read_rejected_reply(self, payload: io.BytesIO) -> sunrpc.RejectedReply:
287285
reject_stat = self._read_enum(payload, sunrpc.RejectStat)
288286
if reject_stat == sunrpc.RejectStat.RPC_MISMATCH:
289-
return self._read_mismatch(payload)
287+
mismatch = self._read_mismatch(payload)
288+
return sunrpc.RejectedReply(reject_stat, mismatch)
290289
elif reject_stat == sunrpc.RejectStat.AUTH_ERROR:
291290
auth_stat = self._read_enum(payload, sunrpc.AuthStat)
292-
return auth_stat
291+
return sunrpc.RejectedReply(reject_stat, auth_stat)
293292

294293
def _read_mismatch(self, payload: io.BytesIO) -> sunrpc.Mismatch:
295294
low = self._read_uint32(payload)
@@ -298,9 +297,9 @@ def _read_mismatch(self, payload: io.BytesIO) -> sunrpc.Mismatch:
298297

299298

300299
class PortMappingSerializer(Serializer[sunrpc.PortMapping]):
301-
def serialize(self, portMapping: sunrpc.PortMapping) -> bytes:
302-
result = self._write_uint32(portMapping.program)
303-
result += self._write_uint32(portMapping.version)
304-
result += self._write_enum(portMapping.protocol)
305-
result += self._write_uint32(portMapping.port)
300+
def serialize(self, port_mapping: sunrpc.PortMapping) -> bytes:
301+
result = self._write_uint32(port_mapping.program)
302+
result += self._write_uint32(port_mapping.version)
303+
result += self._write_enum(port_mapping.protocol)
304+
result += self._write_uint32(port_mapping.port)
306305
return result

0 commit comments

Comments
 (0)