Skip to content

Commit 08dfb00

Browse files
committed
Simplify usage of random
1 parent d704372 commit 08dfb00

File tree

1 file changed

+34
-44
lines changed
  • src/Codec/CBOR/Cuddle/CBOR

1 file changed

+34
-44
lines changed

src/Codec/CBOR/Cuddle/CBOR/Gen.hs

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
{-# LANGUAGE AllowAmbiguousTypes #-}
2+
{-# LANGUAGE CPP #-}
23
{-# LANGUAGE DataKinds #-}
34
{-# LANGUAGE DerivingVia #-}
45
{-# LANGUAGE GADTs #-}
56
{-# LANGUAGE LambdaCase #-}
67
{-# LANGUAGE OverloadedStrings #-}
78
{-# LANGUAGE PatternSynonyms #-}
9+
{-# LANGUAGE ScopedTypeVariables #-}
810
{-# LANGUAGE ViewPatterns #-}
911

1012
-- | Generate example CBOR given a CDDL specification
1113
module Codec.CBOR.Cuddle.CBOR.Gen (generateCBORTerm, generateCBORTerm') where
1214

15+
import qualified Control.Monad.State.Strict as MTL
1316
import Capability.Reader
1417
import Capability.Sink (HasSink)
1518
import Capability.Source (HasSource, MonadState (..))
16-
import Capability.State (HasState, get, modify, state)
19+
import Capability.State (HasState, get, modify)
1720
import Codec.CBOR.Cuddle.CDDL (
1821
Name (..),
1922
OccurrenceIndicator (..),
@@ -45,11 +48,9 @@ import Data.Word (Word32, Word64)
4548
import GHC.Generics (Generic)
4649
import System.Random.Stateful (
4750
Random,
48-
RandomGen (genShortByteString, genWord32, genWord64),
49-
RandomGenM,
50-
StatefulGen (..),
51+
RandomGen (..),
52+
StateGenM (..),
5153
UniformRange (uniformRM),
52-
applyRandomGenM,
5354
randomM,
5455
uniformByteStringM,
5556
)
@@ -59,10 +60,8 @@ import System.Random.Stateful (
5960
--------------------------------------------------------------------------------
6061

6162
-- | Generator context, parametrised over the type of the random seed
62-
data GenEnv g = GenEnv
63+
newtype GenEnv = GenEnv
6364
{ cddl :: CTreeRoot' Identity MonoRef
64-
, fakeSeed :: CapGenM g
65-
-- ^ Access the "fake" seed, necessary to recursively call generators
6665
}
6766
deriving (Generic)
6867

@@ -77,63 +76,55 @@ data GenState g = GenState
7776
}
7877
deriving (Generic)
7978

80-
newtype M g a = M {runM :: StateT (GenState g) (Reader (GenEnv g)) a}
81-
deriving (Functor, Applicative, Monad)
79+
instance RandomGen g => RandomGen (GenState g) where
80+
genWord8 = withRandomSeed genWord8
81+
genWord16 = withRandomSeed genWord16
82+
genWord32 = withRandomSeed genWord32
83+
genWord64 = withRandomSeed genWord64
84+
split s =
85+
case split (randomSeed s) of
86+
(gen', gen) -> (s {randomSeed = gen'}, s {randomSeed = gen})
87+
88+
withRandomSeed :: (t -> (a, g)) -> GenState t -> (a, GenState g)
89+
withRandomSeed f s =
90+
case f (randomSeed s) of
91+
(r, gen) -> (r, s {randomSeed = gen})
92+
93+
newtype M g a = M {runM :: StateT (GenState g) (Reader GenEnv) a}
94+
deriving (Functor, Applicative, Monad, MTL.MonadState (GenState g))
8295
deriving
8396
(HasSource "randomSeed" g, HasSink "randomSeed" g, HasState "randomSeed" g)
8497
via Field
8598
"randomSeed"
8699
()
87-
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
100+
(MonadState (StateT (GenState g) (Reader GenEnv)))
88101
deriving
89102
(HasSource "depth" Int, HasSink "depth" Int, HasState "depth" Int)
90103
via Field
91104
"depth"
92105
()
93-
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
106+
(MonadState (StateT (GenState g) (Reader GenEnv)))
94107
deriving
95108
( HasSource "cddl" (CTreeRoot' Identity MonoRef)
96109
, HasReader "cddl" (CTreeRoot' Identity MonoRef)
97110
)
98111
via Field
99112
"cddl"
100113
()
101-
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
102-
deriving
103-
(HasSource "fakeSeed" (CapGenM g), HasReader "fakeSeed" (CapGenM g))
104-
via Field
105-
"fakeSeed"
106-
()
107-
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
108-
109-
-- | Opaque type carrying the type of a pure PRNG inside a capability-style
110-
-- state monad.
111-
data CapGenM g = CapGenM
114+
(Lift (StateT (GenState g) (MonadReader (Reader GenEnv))))
112115

113-
instance RandomGen g => StatefulGen (CapGenM g) (M g) where
114-
uniformWord64 _ = state @"randomSeed" genWord64
115-
uniformWord32 _ = state @"randomSeed" genWord32
116-
117-
uniformShortByteString n _ = state @"randomSeed" (genShortByteString n)
118-
119-
instance RandomGen r => RandomGenM (CapGenM r) r (M r) where
120-
applyRandomGenM f _ = state @"randomSeed" f
121-
122-
runGen :: M g a -> GenEnv g -> GenState g -> (a, GenState g)
116+
runGen :: M g a -> GenEnv -> GenState g -> (a, GenState g)
123117
runGen m env st = runReader (runStateT (runM m) st) env
124118

125-
evalGen :: M g a -> GenEnv g -> GenState g -> a
119+
evalGen :: M g a -> GenEnv -> GenState g -> a
126120
evalGen m env = fst . runGen m env
127121

128-
asksM :: forall tag r m a. HasReader tag r m => (r -> m a) -> m a
129-
asksM f = f =<< ask @tag
130-
131122
--------------------------------------------------------------------------------
132123
-- Wrappers around some Random function in Gen
133124
--------------------------------------------------------------------------------
134125

135126
genUniformRM :: forall a g. (UniformRange a, RandomGen g) => (a, a) -> M g a
136-
genUniformRM = asksM @"fakeSeed" . uniformRM
127+
genUniformRM r = uniformRM r (StateGenM @(GenState g))
137128

138129
-- | Generate a random number in a given range, biased increasingly towards the
139130
-- lower end as the depth parameter increases.
@@ -143,9 +134,8 @@ genDepthBiasedRM ::
143134
(a, a) ->
144135
M g a
145136
genDepthBiasedRM bounds = do
146-
fs <- ask @"fakeSeed"
147137
d <- get @"depth"
148-
samples <- replicateM d (uniformRM bounds fs)
138+
samples <- replicateM d (genUniformRM bounds)
149139
pure $ minimum samples
150140

151141
-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
@@ -155,10 +145,10 @@ genDepthBiasedBool = do
155145
and <$> replicateM d genRandomM
156146

157147
genRandomM :: forall g a. (Random a, RandomGen g) => M g a
158-
genRandomM = asksM @"fakeSeed" randomM
148+
genRandomM = randomM (StateGenM @(GenState g))
159149

160150
genBytes :: forall g. RandomGen g => Int -> M g ByteString
161-
genBytes n = asksM @"fakeSeed" $ uniformByteStringM n
151+
genBytes n = uniformByteStringM n (StateGenM @(GenState g))
162152

163153
genText :: forall g. RandomGen g => Int -> M g Text
164154
genText n = pure $ T.pack . take n . join $ repeat ['a' .. 'z']
@@ -436,12 +426,12 @@ genValueVariant (VBool b) = pure $ TBool b
436426

437427
generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
438428
generateCBORTerm cddl n stdGen =
439-
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
429+
let genEnv = GenEnv {cddl}
440430
genState = GenState {randomSeed = stdGen, depth = 1}
441431
in evalGen (genForName n) genEnv genState
442432

443433
generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term, g)
444434
generateCBORTerm' cddl n stdGen =
445-
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
435+
let genEnv = GenEnv {cddl}
446436
genState = GenState {randomSeed = stdGen, depth = 1}
447437
in second randomSeed $ runGen (genForName n) genEnv genState

0 commit comments

Comments
 (0)