Skip to content

Commit c567868

Browse files
Validate ocert counter, change encoding of validation failures
1 parent 0a6b93a commit c567868

File tree

5 files changed

+185
-128
lines changed

5 files changed

+185
-128
lines changed

dmq-node/src/DMQ/NodeToClient.hs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module DMQ.NodeToClient
1616
import Data.Aeson qualified as Aeson
1717
import Data.ByteString.Lazy (ByteString)
1818
import Data.Functor.Contravariant ((>$<))
19+
import Data.Typeable
1920
import Data.Void
2021
import Data.Word
2122

@@ -137,15 +138,17 @@ _ntc_MAX_SIGS_TO_ACK = 1000
137138
-- | Construct applications for the node-to-client protocols
138139
--
139140
ntcApps
140-
:: forall crypto idx ntcAddr failure m.
141+
:: forall crypto idx ntcAddr m.
141142
( MonadThrow m
142143
, MonadThread m
143144
, MonadSTM m
144145
, Crypto crypto
145146
, Aeson.ToJSON ntcAddr
146147
, Aeson.ToJSON (MempoolAddFail (Sig crypto))
148+
, Show (MempoolAddFail (Sig crypto))
147149
, ShowProxy (MempoolAddFail (Sig crypto))
148150
, ShowProxy (Sig crypto)
151+
, Typeable crypto
149152
)
150153
=> (forall ev. Aeson.ToJSON ev => Tracer m (WithEventType ev))
151154
-> Configuration
Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
{-# LANGUAGE FlexibleContexts #-}
2-
{-# LANGUAGE OverloadedStrings #-}
3-
{-# LANGUAGE StandaloneDeriving #-}
1+
{-# LANGUAGE FlexibleContexts #-}
2+
{-# LANGUAGE OverloadedStrings #-}
3+
{-# LANGUAGE StandaloneDeriving #-}
4+
{-# LANGUAGE TypeApplications #-}
5+
{-# LANGUAGE UndecidableInstances #-}
46

57
module DMQ.NodeToClient.LocalMsgSubmission where
68

79
import Control.Concurrent.Class.MonadSTM
10+
import Control.Monad.Class.MonadThrow
811
import Control.Tracer
912
import Data.Aeson (ToJSON (..), object, (.=))
1013
import Data.Aeson qualified as Aeson
14+
import Data.Typeable
1115

1216
import DMQ.Protocol.LocalMsgSubmission.Server
1317
import DMQ.Protocol.LocalMsgSubmission.Type
@@ -16,55 +20,80 @@ import Ouroboros.Network.TxSubmission.Mempool.Simple
1620
-- | Local transaction submission server, for adding txs to the 'Mempool'
1721
--
1822
localMsgSubmissionServer ::
19-
MonadSTM m
20-
=> (sig -> sigid)
23+
forall msgid msg idx m.
24+
( MonadSTM m
25+
, MonadThrow m
26+
, Typeable msgid
27+
, Typeable msg
28+
, Show msgid
29+
, Show (MempoolAddFail msg))
30+
=> (msg -> msgid)
2131
-- ^ get message id
22-
-> Tracer m (TraceLocalMsgSubmission sig sigid)
23-
-> MempoolWriter sigid sig failure idx m
32+
-> Tracer m (TraceLocalMsgSubmission msg msgid)
33+
-> MempoolWriter msgid msg idx m
2434
-- ^ duplicate error tag in case the mempool returns the empty list on failure
25-
-> m (LocalMsgSubmissionServer sig m ())
35+
-> m (LocalMsgSubmissionServer msg m ())
2636
localMsgSubmissionServer getMsgId tracer MempoolWriter { mempoolAddTxs } =
2737
pure server
2838
where
29-
process (sigid, e@(SubmitFail reason)) =
30-
(e, server) <$ traceWith tracer (TraceSubmitFailure sigid reason)
31-
process (sigid, success) =
32-
(success, server) <$ traceWith tracer (TraceSubmitAccept sigid)
39+
process (Left (msgid, reason)) = do
40+
(SubmitFail reason, server) <$ traceWith tracer (TraceSubmitFailure msgid reason)
41+
42+
process (Right [(msgid, e@(SubmitFail reason))]) =
43+
(e, server) <$ traceWith tracer (TraceSubmitFailure msgid reason)
44+
45+
process (Right [(msgid, SubmitSuccess)]) =
46+
(SubmitSuccess, server) <$ traceWith tracer (TraceSubmitAccept msgid)
47+
48+
process _ = throwIO (TooManyMessages @msgid @msg)
3349

3450
server = LocalTxSubmissionServer {
35-
recvMsgSubmitTx = \sig -> do
36-
traceWith tracer $ TraceReceivedMsg (getMsgId sig)
37-
process . head =<< mempoolAddTxs [sig]
51+
recvMsgSubmitTx = \msg -> do
52+
traceWith tracer $ TraceReceivedMsg (getMsgId msg)
53+
process =<< mempoolAddTxs [msg]
3854

3955
, recvMsgDone = ()
4056
}
4157

4258

43-
data TraceLocalMsgSubmission sig sigid =
44-
TraceReceivedMsg sigid
59+
data TraceLocalMsgSubmission msg msgid =
60+
TraceReceivedMsg msgid
4561
-- ^ A signature was received.
46-
| TraceSubmitFailure sigid (MempoolAddFail sig)
47-
| TraceSubmitAccept sigid
62+
| TraceSubmitFailure msgid (MempoolAddFail msg)
63+
| TraceSubmitAccept msgid
4864

4965
deriving instance
50-
(Show sig, Show sigid, Show (MempoolAddFail sig))
51-
=> Show (TraceLocalMsgSubmission sig sigid)
66+
(Show msg, Show msgid, Show (MempoolAddFail msg))
67+
=> Show (TraceLocalMsgSubmission msg msgid)
68+
69+
70+
71+
data MsgSubmissionServerException msgid msg =
72+
MsgValidationException msgid (MempoolAddFail msg)
73+
| TooManyMessages
74+
75+
deriving instance (Show (MempoolAddFail msg), Show msgid)
76+
=> Show (MsgSubmissionServerException msgid msg)
77+
78+
instance (Typeable msgid, Typeable msg, Show (MempoolAddFail msg), Show msgid)
79+
=> Exception (MsgSubmissionServerException msgid msg) where
80+
5281

53-
instance (ToJSON sigid, ToJSON (MempoolAddFail sig))
54-
=> ToJSON (TraceLocalMsgSubmission sig sigid) where
55-
toJSON (TraceReceivedMsg sigid) =
82+
instance (ToJSON msgid, ToJSON (MempoolAddFail msg))
83+
=> ToJSON (TraceLocalMsgSubmission msg msgid) where
84+
toJSON (TraceReceivedMsg msgid) =
5685
-- TODO: once we have verbosity levels, we could include the full tx, for
5786
-- now one can use `TraceSendRecv` tracer for the mini-protocol to see full
5887
-- msgs.
5988
object [ "kind" .= Aeson.String "TraceReceivedMsg"
60-
, "sigId" .= sigid
89+
, "sigId" .= msgid
6190
]
62-
toJSON (TraceSubmitFailure sigid reject) =
91+
toJSON (TraceSubmitFailure msgid reject) =
6392
object [ "kind" .= Aeson.String "TraceSubmitFailure"
64-
, "sigId" .= sigid
93+
, "sigId" .= msgid
6594
, "reason" .= reject
6695
]
67-
toJSON (TraceSubmitAccept sigid) =
96+
toJSON (TraceSubmitAccept msgid) =
6897
object [ "kind" .= Aeson.String "TraceSubmitAccept"
69-
, "sigId" .= sigid
98+
, "sigId" .= msgid
7099
]

dmq-node/src/DMQ/NodeToNode.hs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ import Codec.CBOR.Encoding qualified as CBOR
4141
import Codec.CBOR.Read qualified as CBOR
4242
import Codec.CBOR.Term qualified as CBOR
4343
import Data.Aeson qualified as Aeson
44-
import Data.ByteString qualified as BS
4544
import Data.ByteString.Lazy qualified as BL
4645
import Data.Functor.Contravariant ((>$<))
4746
import Data.Hashable (Hashable)
@@ -54,10 +53,7 @@ import Network.Mux.Types (Mode (..))
5453
import Network.Mux.Types qualified as Mx
5554
import Network.TypedProtocol.Codec (AnnotatedCodec, Codec)
5655

57-
import Cardano.Crypto.DSIGN.Class qualified as DSIGN
58-
import Cardano.Crypto.KES.Class qualified as KES
5956
import Cardano.KESAgent.KES.Crypto (Crypto (..))
60-
import Cardano.KESAgent.KES.OCert (OCertSignable)
6157

6258
import DMQ.Configuration (Configuration, Configuration' (..), I (..))
6359
import DMQ.Diffusion.NodeKernel (NodeKernel (..))
@@ -202,8 +198,7 @@ ntnApps
202198
fetchClientRegistry
203199
, peerSharingRegistry
204200
, peerSharingAPI
205-
, mempool
206-
, evolutionConfig
201+
-- , mempool
207202
, sigChannelVar
208203
, sigMempoolSem
209204
, sigSharedTxStateVar

dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Codec.hs

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ import Codec.CBOR.Encoding qualified as CBOR
1010
import Codec.CBOR.Read qualified as CBOR
1111
import Control.Monad.Class.MonadST
1212
import Data.ByteString.Lazy (ByteString)
13+
import Data.Text qualified as T
14+
import Data.Tuple (swap)
1315
import Text.Printf
1416

17+
import Cardano.Binary
1518
import Cardano.KESAgent.KES.Crypto (Crypto (..))
1619

1720
import DMQ.Protocol.LocalMsgSubmission.Type
@@ -35,18 +38,53 @@ codecLocalMsgSubmission =
3538

3639
encodeReject :: MempoolAddFail (Sig crypto) -> CBOR.Encoding
3740
encodeReject = \case
38-
SigInvalid reason -> CBOR.encodeListLen 2 <> CBOR.encodeWord 0 <> CBOR.encodeString reason
39-
SigDuplicate -> CBOR.encodeListLen 1 <> CBOR.encodeWord 1
40-
SigExpired -> CBOR.encodeListLen 1 <> CBOR.encodeWord 2
41+
SigInvalid reason -> CBOR.encodeListLen 2 <> CBOR.encodeWord8 0 <> e
42+
where
43+
e = case reason of
44+
InvalidKESSignature ocertKESPeriod sigKESPeriod err -> mconcat [
45+
CBOR.encodeListLen 4, CBOR.encodeWord8 0, toCBOR ocertKESPeriod, toCBOR sigKESPeriod, CBOR.encodeString (T.pack err)
46+
]
47+
InvalidSignatureOCERT ocertN sigKESPeriod err -> mconcat [
48+
CBOR.encodeListLen 4, CBOR.encodeWord8 1, CBOR.encodeWord64 ocertN, toCBOR sigKESPeriod, CBOR.encodeString (T.pack err)
49+
]
50+
KESBeforeStartOCERT startKESPeriod sigKESPeriod -> mconcat [
51+
CBOR.encodeListLen 3, CBOR.encodeWord8 2, toCBOR startKESPeriod, toCBOR sigKESPeriod
52+
]
53+
KESAfterEndOCERT endKESPeriod sigKESPeriod -> mconcat [
54+
CBOR.encodeListLen 3, CBOR.encodeWord8 3, toCBOR endKESPeriod, toCBOR sigKESPeriod
55+
]
56+
UnrecognizedPool -> CBOR.encodeListLen 1 <> CBOR.encodeWord8 4
57+
NotInitialized -> CBOR.encodeListLen 1 <> CBOR.encodeWord8 5
58+
ClockSkew -> CBOR.encodeListLen 1 <> CBOR.encodeWord8 6
59+
InvalidOCertCounter seen received
60+
-> mconcat
61+
[CBOR.encodeListLen 3, CBOR.encodeWord8 7, CBOR.encodeWord64 seen, CBOR.encodeWord64 received]
62+
SigDuplicate -> CBOR.encodeListLen 1 <> CBOR.encodeWord8 1
63+
SigExpired -> CBOR.encodeListLen 1 <> CBOR.encodeWord8 2
4164
SigResultOther reason
42-
-> CBOR.encodeListLen 2 <> CBOR.encodeWord 3 <> CBOR.encodeString reason
65+
-> CBOR.encodeListLen 2 <> CBOR.encodeWord8 3 <> CBOR.encodeString reason
4366

4467
decodeReject :: CBOR.Decoder s (MempoolAddFail (Sig crypto))
4568
decodeReject = do
4669
len <- CBOR.decodeListLen
47-
tag <- CBOR.decodeWord
70+
tag <- CBOR.decodeWord8
4871
case (tag, len) of
49-
(0, 2) -> SigInvalid <$> CBOR.decodeString
72+
(0, 2) -> SigInvalid <$> decSigValidError
73+
where
74+
decSigValidError :: CBOR.Decoder s SigValidationError
75+
decSigValidError = do
76+
lenTag <- (,) <$> CBOR.decodeListLen <*> CBOR.decodeWord8
77+
case swap lenTag of
78+
(0, 4) -> InvalidKESSignature <$> fromCBOR <*> fromCBOR <*> (T.unpack <$> CBOR.decodeString)
79+
(1, 4) -> InvalidSignatureOCERT <$> CBOR.decodeWord64 <*> fromCBOR <*> (T.unpack <$> CBOR.decodeString)
80+
(2, 3) -> KESBeforeStartOCERT <$> fromCBOR <*> fromCBOR
81+
(3, 4) -> KESAfterEndOCERT <$> fromCBOR <*> fromCBOR
82+
(4, 1) -> pure UnrecognizedPool
83+
(5, 1) -> pure NotInitialized
84+
(6, 1) -> pure ClockSkew
85+
(7, 3) -> InvalidOCertCounter <$> fromCBOR <*> fromCBOR
86+
_otherwise -> fail $ printf "unrecognized (tag,len) = (%d, %d) when decoding SigInvalid tag" tag len
87+
5088
(1, 1) -> pure SigDuplicate
5189
(2, 1) -> pure SigExpired
5290
(3, 2) -> SigResultOther <$> CBOR.decodeString

0 commit comments

Comments
 (0)