1
1
{-# LANGUAGE AllowAmbiguousTypes #-}
2
+ {-# LANGUAGE CPP #-}
2
3
{-# LANGUAGE DataKinds #-}
3
4
{-# LANGUAGE DerivingVia #-}
4
5
{-# LANGUAGE GADTs #-}
5
6
{-# LANGUAGE LambdaCase #-}
6
7
{-# LANGUAGE OverloadedStrings #-}
7
8
{-# LANGUAGE PatternSynonyms #-}
9
+ {-# LANGUAGE ScopedTypeVariables #-}
8
10
{-# LANGUAGE ViewPatterns #-}
9
11
10
12
-- | Generate example CBOR given a CDDL specification
11
13
module Codec.CBOR.Cuddle.CBOR.Gen (generateCBORTerm , generateCBORTerm' ) where
12
14
15
+ import qualified Control.Monad.State.Strict as MTL
13
16
import Capability.Reader
14
17
import Capability.Sink (HasSink )
15
18
import Capability.Source (HasSource , MonadState (.. ))
16
- import Capability.State (HasState , get , modify , state )
19
+ import Capability.State (HasState , get , modify )
17
20
import Codec.CBOR.Cuddle.CDDL (
18
21
Name (.. ),
19
22
OccurrenceIndicator (.. ),
@@ -45,11 +48,9 @@ import Data.Word (Word32, Word64)
45
48
import GHC.Generics (Generic )
46
49
import System.Random.Stateful (
47
50
Random ,
48
- RandomGen (genShortByteString , genWord32 , genWord64 ),
49
- RandomGenM ,
50
- StatefulGen (.. ),
51
+ RandomGen (.. ),
52
+ StateGenM (.. ),
51
53
UniformRange (uniformRM ),
52
- applyRandomGenM ,
53
54
randomM ,
54
55
uniformByteStringM ,
55
56
)
@@ -59,10 +60,8 @@ import System.Random.Stateful (
59
60
--------------------------------------------------------------------------------
60
61
61
62
-- | Generator context, parametrised over the type of the random seed
62
- data GenEnv g = GenEnv
63
+ newtype GenEnv = GenEnv
63
64
{ cddl :: CTreeRoot' Identity MonoRef
64
- , fakeSeed :: CapGenM g
65
- -- ^ Access the "fake" seed, necessary to recursively call generators
66
65
}
67
66
deriving (Generic )
68
67
@@ -77,63 +76,55 @@ data GenState g = GenState
77
76
}
78
77
deriving (Generic )
79
78
80
- newtype M g a = M { runM :: StateT (GenState g ) (Reader (GenEnv g )) a }
81
- deriving (Functor , Applicative , Monad )
79
+ instance RandomGen g => RandomGen (GenState g ) where
80
+ genWord8 = withRandomSeed genWord8
81
+ genWord16 = withRandomSeed genWord16
82
+ genWord32 = withRandomSeed genWord32
83
+ genWord64 = withRandomSeed genWord64
84
+ split s =
85
+ case split (randomSeed s) of
86
+ (gen', gen) -> (s {randomSeed = gen'}, s {randomSeed = gen})
87
+
88
+ withRandomSeed :: (t -> (a , g )) -> GenState t -> (a , GenState g )
89
+ withRandomSeed f s =
90
+ case f (randomSeed s) of
91
+ (r, gen) -> (r, s {randomSeed = gen})
92
+
93
+ newtype M g a = M { runM :: StateT (GenState g ) (Reader GenEnv ) a }
94
+ deriving (Functor , Applicative , Monad , MTL.MonadState (GenState g))
82
95
deriving
83
96
(HasSource " randomSeed" g , HasSink " randomSeed" g , HasState " randomSeed" g )
84
97
via Field
85
98
" randomSeed"
86
99
()
87
- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
100
+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
88
101
deriving
89
102
(HasSource " depth" Int , HasSink " depth" Int , HasState " depth" Int )
90
103
via Field
91
104
" depth"
92
105
()
93
- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
106
+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
94
107
deriving
95
108
( HasSource " cddl" (CTreeRoot' Identity MonoRef )
96
109
, HasReader " cddl" (CTreeRoot' Identity MonoRef )
97
110
)
98
111
via Field
99
112
" cddl"
100
113
()
101
- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
102
- deriving
103
- (HasSource " fakeSeed" (CapGenM g ), HasReader " fakeSeed" (CapGenM g ))
104
- via Field
105
- " fakeSeed"
106
- ()
107
- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
108
-
109
- -- | Opaque type carrying the type of a pure PRNG inside a capability-style
110
- -- state monad.
111
- data CapGenM g = CapGenM
114
+ (Lift (StateT (GenState g ) (MonadReader (Reader GenEnv ))))
112
115
113
- instance RandomGen g => StatefulGen (CapGenM g ) (M g ) where
114
- uniformWord64 _ = state @ " randomSeed" genWord64
115
- uniformWord32 _ = state @ " randomSeed" genWord32
116
-
117
- uniformShortByteString n _ = state @ " randomSeed" (genShortByteString n)
118
-
119
- instance RandomGen r => RandomGenM (CapGenM r ) r (M r ) where
120
- applyRandomGenM f _ = state @ " randomSeed" f
121
-
122
- runGen :: M g a -> GenEnv g -> GenState g -> (a , GenState g )
116
+ runGen :: M g a -> GenEnv -> GenState g -> (a , GenState g )
123
117
runGen m env st = runReader (runStateT (runM m) st) env
124
118
125
- evalGen :: M g a -> GenEnv g -> GenState g -> a
119
+ evalGen :: M g a -> GenEnv -> GenState g -> a
126
120
evalGen m env = fst . runGen m env
127
121
128
- asksM :: forall tag r m a . HasReader tag r m => (r -> m a ) -> m a
129
- asksM f = f =<< ask @ tag
130
-
131
122
--------------------------------------------------------------------------------
132
123
-- Wrappers around some Random function in Gen
133
124
--------------------------------------------------------------------------------
134
125
135
126
genUniformRM :: forall a g . (UniformRange a , RandomGen g ) => (a , a ) -> M g a
136
- genUniformRM = asksM @ " fakeSeed " . uniformRM
127
+ genUniformRM r = uniformRM r ( StateGenM @ ( GenState g ))
137
128
138
129
-- | Generate a random number in a given range, biased increasingly towards the
139
130
-- lower end as the depth parameter increases.
@@ -143,9 +134,8 @@ genDepthBiasedRM ::
143
134
(a , a ) ->
144
135
M g a
145
136
genDepthBiasedRM bounds = do
146
- fs <- ask @ " fakeSeed"
147
137
d <- get @ " depth"
148
- samples <- replicateM d (uniformRM bounds fs )
138
+ samples <- replicateM d (genUniformRM bounds)
149
139
pure $ minimum samples
150
140
151
141
-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
@@ -155,10 +145,10 @@ genDepthBiasedBool = do
155
145
and <$> replicateM d genRandomM
156
146
157
147
genRandomM :: forall g a . (Random a , RandomGen g ) => M g a
158
- genRandomM = asksM @ " fakeSeed " randomM
148
+ genRandomM = randomM ( StateGenM @ ( GenState g ))
159
149
160
150
genBytes :: forall g . RandomGen g => Int -> M g ByteString
161
- genBytes n = asksM @ " fakeSeed " $ uniformByteStringM n
151
+ genBytes n = uniformByteStringM n ( StateGenM @ ( GenState g ))
162
152
163
153
genText :: forall g . RandomGen g => Int -> M g Text
164
154
genText n = pure $ T. pack . take n . join $ repeat [' a' .. ' z' ]
@@ -436,12 +426,12 @@ genValueVariant (VBool b) = pure $ TBool b
436
426
437
427
generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
438
428
generateCBORTerm cddl n stdGen =
439
- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
429
+ let genEnv = GenEnv {cddl}
440
430
genState = GenState {randomSeed = stdGen, depth = 1 }
441
431
in evalGen (genForName n) genEnv genState
442
432
443
433
generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term , g )
444
434
generateCBORTerm' cddl n stdGen =
445
- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
435
+ let genEnv = GenEnv {cddl}
446
436
genState = GenState {randomSeed = stdGen, depth = 1 }
447
437
in second randomSeed $ runGen (genForName n) genEnv genState
0 commit comments