22{-# LANGUAGE FlexibleInstances #-}
33{-# LANGUAGE MultiWayIf #-}
44{-# LANGUAGE OverloadedStrings #-}
5- {-# LANGUAGE StandaloneDeriving #-}
5+ {-# LANGUAGE PatternSynonyms #-}
66{-# LANGUAGE TupleSections #-}
77{-# LANGUAGE TypeFamilies #-}
88{-# LANGUAGE TypeOperators #-}
9+ {-# LANGUAGE ViewPatterns #-}
10+
11+ {-# OPTIONS_GHC -fno-warn-orphans #-}
912
1013-- | Encapsulates signature validation utilities leveraged by the mempool writer
1114--
1215module DMQ.Protocol.SigSubmission.Validate where
1316
1417import Control.Monad
1518import Control.Concurrent.Class.MonadSTM.Strict
19+ import Control.Exception (Exception )
1620import Control.Monad.Class.MonadTime.SI
1721import Control.Monad.Trans.Class
1822import Control.Monad.Trans.Except
@@ -45,16 +49,18 @@ import Ouroboros.Network.Util.ShowProxy
4549-- | The type of non-fatal failures reported by the mempool writer
4650-- for invalid messages
4751--
48- data instance MempoolAddFail (Sig crypto ) =
52+ data instance TxValidationFail (Sig crypto ) =
4953 SigInvalid SigValidationError
5054 | SigDuplicate
5155 | SigExpired
5256 | SigResultOther Text
5357 deriving (Eq , Show )
5458
55- instance (Typeable crypto ) => ShowProxy (MempoolAddFail (Sig crypto ))
59+ instance (Typeable crypto ) => ShowProxy (TxValidationFail (Sig crypto ))
60+
61+ instance (Typeable crypto ) => Exception (TxValidationFail (Sig crypto ))
5662
57- instance ToJSON (MempoolAddFail (Sig crypto )) where
63+ instance ToJSON (TxValidationFail (Sig crypto )) where
5864 toJSON SigDuplicate = String " duplicate"
5965 toJSON SigExpired = String " expired"
6066 toJSON (SigInvalid e) = object
@@ -84,6 +90,21 @@ data SigValidationError =
8490 deriving (Eq , Show )
8591
8692
93+ c_MAX_CLOCK_SKEW_SEC :: NominalDiffTime
94+ c_MAX_CLOCK_SKEW_SEC = 5
95+
96+ pattern NotZeroSetSnapshot :: StakeSnapshot
97+ pattern NotZeroSetSnapshot <- (isZero . ssSetPool -> False )
98+
99+ pattern NotZeroMarkSnapshot :: StakeSnapshot
100+ pattern NotZeroMarkSnapshot <- (isZero . ssMarkPool -> False )
101+
102+ pattern ZeroSetSnapshot :: StakeSnapshot
103+ pattern ZeroSetSnapshot <- (isZero . ssSetPool -> True )
104+
105+ {-# COMPLETE NotZeroSetSnapshot, NotZeroMarkSnapshot, ZeroSetSnapshot #-}
106+
107+
87108-- TODO:
88109-- We don't validate ocert numbers, since we might not have necessary
89110-- information to do so, but we can validate that they are growing.
@@ -99,9 +120,9 @@ validateSig :: forall crypto m.
99120 -> [Sig crypto ]
100121 -> PoolValidationCtx m
101122 -- ^ cardano pool id verification
102- -> ExceptT (Sig crypto , MempoolAddFail (Sig crypto )) m
103- [(Sig crypto , Either (MempoolAddFail (Sig crypto )) () )]
104- validateSig _ec verKeyHashingFn sigs ctx = traverse process' sigs
123+ -> ExceptT (Sig crypto , TxValidationFail (Sig crypto )) m
124+ [(Sig crypto , Either (TxValidationFail (Sig crypto )) () )]
125+ validateSig verKeyHashingFn sigs ctx = traverse process' sigs
105126 where
106127 DMQPoolValidationCtx now mNextEpoch pools ocertCountersVar = ctx
107128
@@ -123,33 +144,31 @@ validateSig _ec verKeyHashingFn sigs ctx = traverse process' sigs
123144 ?! KESBeforeStartOCERT startKESPeriod sigKESPeriod
124145 e <- case Map. lookup (verKeyHashingFn coldKey) pools of
125146 Nothing | isNothing mNextEpoch
126- -> invalid SigResultOther $ Text. pack " not initialized yet"
147+ -> right . Left . SigResultOther $ Text. pack " not initialized yet"
127148 | otherwise
128149 -> left $ SigInvalid UnrecognizedPool
129- -- TODO make 5 a constant
130- Just ss | not (isZero (ssSetPool ss)) ->
150+ Just ss | NotZeroSetSnapshot <- ss ->
131151 if | now < nextEpoch -> success
132152 -- localstatequery is late, but the pool is about to expire
133153 | isZero (ssMarkPool ss)
134- , now > addUTCTime 5 nextEpoch -> left SigExpired
154+ , now > addUTCTime c_MAX_CLOCK_SKEW_SEC nextEpoch -> left SigExpired
135155 -- we bound the time we're willing to approve a message
136156 -- in case smth happened to localstatequery and it's taking
137157 -- too long to update our state
138- | now <= addUTCTime 5 nextEpoch -> success
139- | otherwise -> left $ SigInvalid ClockSkew
140- | not (isZero (ssMarkPool ss)) ->
158+ | now <= addUTCTime c_MAX_CLOCK_SKEW_SEC nextEpoch -> success
159+ | otherwise -> right . Left $ SigInvalid ClockSkew
160+ | NotZeroMarkSnapshot <- ss ->
141161 -- we take abs time in case we're late with our own
142162 -- localstatequery update, and/or the other side's clock
143163 -- is ahead, and we're just about or have just crossed the epoch
144164 -- and the pool is expected to move into the set mark
145- if | abs (diffUTCTime nextEpoch now) <= 5 -> success
146- | diffUTCTime nextEpoch now > 5 ->
165+ if | abs (diffUTCTime nextEpoch now) <= c_MAX_CLOCK_SKEW_SEC -> success
166+ | diffUTCTime nextEpoch now > c_MAX_CLOCK_SKEW_SEC ->
147167 left . SigResultOther $ Text. pack " pool not eligible yet"
148168 | otherwise -> right . Left $ SigInvalid ClockSkew
149169 -- pool is deregistered and ineligible to mint blocks
150- | isZero (ssSetPool ss) ->
170+ | ZeroSetSnapshot <- ss ->
151171 left SigExpired
152- | otherwise -> error " validateSig: impossible pool validation error"
153172 where
154173 -- mNextEpoch and pools are initialized in one STM transaction
155174 -- and fromJust will not fail here
@@ -167,15 +186,14 @@ validateSig _ec verKeyHashingFn sigs ctx = traverse process' sigs
167186 let f = \ case
168187 Nothing -> Right $ Just ocertN
169188 Just n | n <= ocertN -> Right $ Just ocertN
170- | otherwise -> Left . throwE . SigInvalid $ InvalidOCertCounter n ocertN
189+ | otherwise -> Left $ InvalidOCertCounter n ocertN
171190 in case Map. alterF f (verKeyHashingFn coldKey) ocertCounters of
172191 Right ocertCounters' -> (void success, ocertCounters')
173- Left err -> (err, ocertCounters)
192+ Left err -> (throwE ( SigInvalid err) , ocertCounters)
174193 -- for eg. remember to run all results with possibly non-fatal errors
175194 right e
176195 where
177196 success = right $ Right ()
178- invalid tag = right . Left . tag
179197
180198 startKESPeriod , endKESPeriod :: KESPeriod
181199
@@ -187,12 +205,12 @@ validateSig _ec verKeyHashingFn sigs ctx = traverse process' sigs
187205
188206 (?!:) :: Either e1 ()
189207 -> (e1 -> SigValidationError )
190- -> ExceptT (MempoolAddFail (Sig crypto )) m ()
208+ -> ExceptT (TxValidationFail (Sig crypto )) m ()
191209 (?!:) result f = firstExceptT (SigInvalid . f) . hoistEither $ result
192210
193211 (?!) :: Bool
194212 -> SigValidationError
195- -> ExceptT (MempoolAddFail (Sig crypto )) m ()
213+ -> ExceptT (TxValidationFail (Sig crypto )) m ()
196214 (?!) flag sve = if flag then void success else left (SigInvalid sve)
197215
198216 infix 1 ?!
0 commit comments