diff --git a/protocol/localtxmonitor/messages.go b/protocol/localtxmonitor/messages.go index eae8cd19..5f7c0eb9 100644 --- a/protocol/localtxmonitor/messages.go +++ b/protocol/localtxmonitor/messages.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2025 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package localtxmonitor import ( "fmt" + "math" "github.com/blinklabs-io/gouroboros/cbor" "github.com/blinklabs-io/gouroboros/protocol" @@ -170,13 +171,21 @@ func (m *MsgReplyNextTx) UnmarshalCBOR(data []byte) error { if tmp == nil { return nil } + messageType64 := tmp[0].(uint64) + if messageType64 > math.MaxUint8 { + return fmt.Errorf("message type integer overflow") + } // We know what the value will be, but it doesn't hurt to use the actual value from the message - m.MessageType = uint8(tmp[0].(uint64)) + m.MessageType = uint8(messageType64) // The ReplyNextTx message has a variable number of arguments if len(tmp) > 1 { txWrapper := tmp[1].([]interface{}) + eraId64 := txWrapper[0].(uint64) + if eraId64 > math.MaxUint8 { + return fmt.Errorf("era id integer overflow") + } m.Transaction = MsgReplyNextTxTransaction{ - EraId: uint8(txWrapper[0].(uint64)), + EraId: uint8(eraId64), Tx: txWrapper[1].(cbor.WrappedCbor).Bytes(), } } diff --git a/protocol/localtxmonitor/server.go b/protocol/localtxmonitor/server.go index e951f110..5341708f 100644 --- a/protocol/localtxmonitor/server.go +++ b/protocol/localtxmonitor/server.go @@ -1,4 +1,4 @@ -// Copyright 2024 Blink Labs Software +// Copyright 2025 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package localtxmonitor import ( "encoding/hex" "fmt" + "math" "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" @@ -193,6 +194,9 @@ func (s *Server) handleNextTx() error { return nil } mempoolTx := s.mempoolTxs[s.mempoolNextTxIdx] + if mempoolTx.EraId > math.MaxUint8 { + return fmt.Errorf("integer overflow in era id") + } newMsg := NewMsgReplyNextTx(uint8(mempoolTx.EraId), mempoolTx.Tx) if err := s.SendMessage(newMsg); err != nil { return err @@ -213,10 +217,18 @@ func (s *Server) handleGetSizes() error { for _, tx := range s.mempoolTxs { totalTxSize += len(tx.Tx) } + numTxs := len(s.mempoolTxs) + // check for over/underflows + if totalTxSize < 0 || totalTxSize > math.MaxUint32 { + return fmt.Errorf("integrer overflow in total tx size") + } + if numTxs < 0 || numTxs > math.MaxUint32 { + return fmt.Errorf("integrer overflow in tx count") + } newMsg := NewMsgReplyGetSizes( s.mempoolCapacity, - uint32(totalTxSize), - uint32(len(s.mempoolTxs)), + uint32(totalTxSize), // #nosec G115 + uint32(numTxs), ) if err := s.SendMessage(newMsg); err != nil { return err