Skip to content

Commit 4534b6e

Browse files
committed
io-sim-por: use ST monad rather than unsafePerformIO
1 parent 6d62b8b commit 4534b6e

File tree

2 files changed

+146
-108
lines changed

2 files changed

+146
-108
lines changed

io-sim/src/Control/Monad/IOSim.hs

Lines changed: 131 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
{-# LANGUAGE BangPatterns #-}
12
{-# LANGUAGE ExplicitNamespaces #-}
23
{-# LANGUAGE NamedFieldPuns #-}
34
{-# LANGUAGE RankNTypes #-}
45
{-# LANGUAGE ScopedTypeVariables #-}
56
{-# LANGUAGE TupleSections #-}
67

78
{-# OPTIONS_GHC -Wno-name-shadowing #-}
9+
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
10+
811
module Control.Monad.IOSim
912
( -- * Simulation monad
1013
IOSim
@@ -88,6 +91,7 @@ import Prelude
8891
import Data.Bifoldable
8992
import Data.Dynamic (fromDynamic)
9093
import Data.List (intercalate)
94+
import Data.Maybe (catMaybes)
9195
import Data.Set (Set)
9296
import qualified Data.Set as Set
9397
import Data.Typeable (Typeable)
@@ -97,6 +101,7 @@ import Data.List.Trace (Trace (..))
97101
import Control.Exception (throw)
98102

99103
import Control.Monad.ST.Lazy
104+
import Data.STRef.Lazy
100105

101106
import Control.Monad.Class.MonadThrow as MonadThrow
102107

@@ -108,7 +113,6 @@ import Control.Monad.IOSimPOR.QuickCheckUtils
108113
import Test.QuickCheck
109114

110115

111-
import Data.IORef
112116
import System.IO.Unsafe
113117

114118

@@ -161,23 +165,23 @@ selectTraceRaces = go
161165
-- unsafe, of course, since that function may return different results
162166
-- at different times.
163167

164-
detachTraceRaces :: forall a. SimTrace a -> (() -> [ScheduleControl], SimTrace a)
165-
detachTraceRaces trace = unsafePerformIO $ do
166-
races <- newIORef []
167-
let readRaces :: () -> [ScheduleControl]
168-
readRaces () = concat . reverse . unsafePerformIO $ readIORef races
168+
detachTraceRacesST :: forall a s. SimTrace a -> ST s (ST s [ScheduleControl], SimTrace a)
169+
detachTraceRacesST trace0 = do
170+
races <- newSTRef []
171+
let readRaces :: ST s [ScheduleControl]
172+
readRaces = concat . reverse <$> readSTRef races
169173

170-
saveRaces :: [ScheduleControl] -> x -> x
171-
saveRaces rs t = unsafePerformIO $ modifyIORef races (rs:)
172-
>> return t
174+
saveRaces :: [ScheduleControl] -> ST s ()
175+
saveRaces rs = modifySTRef races (rs:)
173176

174-
go :: SimTrace a -> SimTrace a
175-
go (SimTrace a b c d trace) = SimTrace a b c d $ go trace
176-
go (SimPORTrace a b c d e trace) = SimPORTrace a b c d e $ go trace
177-
go (TraceRacesFound rs trace) = saveRaces rs $ go trace
178-
go t = t
177+
go :: SimTrace a -> ST s (SimTrace a)
178+
go (SimTrace a b c d trace) = SimTrace a b c d <$> go trace
179+
go (SimPORTrace a b c d e trace) = SimPORTrace a b c d e <$> go trace
180+
go (TraceRacesFound rs trace) = saveRaces rs >> go trace
181+
go t = return t
179182

180-
return (readRaces, go trace)
183+
trace <- go trace0
184+
return (readRaces, trace)
181185

182186
-- | Select all the traced values matching the expected type. This relies on
183187
-- the sim's dynamic trace facility.
@@ -472,53 +476,59 @@ exploreSimTrace
472476
exploreSimTrace optsf mainAction k =
473477
case explorationReplay opts of
474478
Nothing ->
475-
explore (explorationScheduleBound opts) (explorationBranching opts) ControlDefault Nothing .&&.
476-
let size = cacheSize() in size `seq`
477-
tabulate "Modified schedules explored" [bucket size] True
479+
case runST (do cacheRef <- createCacheST
480+
prop <- explore cacheRef (explorationScheduleBound opts) (explorationBranching opts) ControlDefault Nothing
481+
size <- cacheSizeST cacheRef
482+
return (prop, size)
483+
) of
484+
(prop, !size) -> tabulate "Modified schedules explored" [bucket size] prop
485+
478486
Just control ->
479487
replaySimTrace opts mainAction control (k Nothing)
488+
480489
where
481490
opts = optsf stdExplorationOptions
482491

483-
explore :: Int -- schedule bound
492+
explore :: forall s.
493+
STRef s (Set ScheduleControl)
494+
-> Int -- schedule bound
484495
-> Int -- branching factor
485-
-> ScheduleControl -> Maybe (SimTrace a) -> Property
486-
explore n m control passingTrace =
487-
488-
-- ALERT!!! Impure code: readRaces must be called *after* we have
489-
-- finished with trace.
490-
let (readRaces, trace0) = detachTraceRaces $
491-
controlSimTrace
492-
(explorationStepTimelimit opts) control mainAction
493-
(sleeper,trace) = compareTraces passingTrace trace0
494-
in ( counterexample ("Schedule control: " ++ show control)
495-
$ counterexample
496-
(case sleeper of
497-
Nothing -> "No thread delayed"
498-
Just ((t,tid,lab),racing) ->
499-
showThread (tid,lab) ++
500-
" delayed at time "++
501-
show t ++
502-
"\n until after:\n" ++
503-
unlines (map ((" "++).showThread) $ Set.toList racing)
504-
)
505-
$ k passingTrace trace
506-
)
507-
.&&| let limit = (n+m-1) `div` m
508-
-- To ensure the set of schedules explored is deterministic, we
509-
-- filter out cached ones *after* selecting the children of this
510-
-- node.
511-
races = filter (not . cached) . take limit $ readRaces ()
512-
branching = length races
513-
in -- tabulate "Races explored" (map show races) $
514-
tabulate "Branching factor" [bucket branching] $
515-
tabulate "Race reversals per schedule" [bucket (raceReversals control)] $
516-
conjoinPar
517-
[ --Debug.trace "New schedule:" $
518-
--Debug.trace (" "++show r) $
519-
--counterexample ("Schedule control: " ++ show r) $
520-
explore n' ((m-1) `max` 1) r (Just trace0)
521-
| (r,n') <- zip races (divide (n-branching) branching) ]
496+
-> ScheduleControl -> Maybe (SimTrace a) -> ST s Property
497+
explore cacheRef n m control passingTrace = do
498+
traceWithRaces <- controlSimTraceST (explorationStepTimelimit opts) control mainAction
499+
(readRaces, trace0) <- detachTraceRacesST traceWithRaces
500+
(readSleeperST, trace) <- compareTracesST passingTrace trace0
501+
conjoinNoCatchST
502+
[ do sleeper <- readSleeperST
503+
return $ counterexample ("Schedule control: " ++ show control)
504+
$ counterexample
505+
(case sleeper of
506+
Nothing -> "No thread delayed"
507+
Just ((t,tid,lab),racing) ->
508+
showThread (tid,lab) ++
509+
" delayed at time "++
510+
show t ++
511+
"\n until after:\n" ++
512+
unlines (map ((" "++).showThread) $ Set.toList racing)
513+
)
514+
$ k passingTrace trace
515+
, do let limit = (n+m-1) `div` m
516+
-- To ensure the set of schedules explored is deterministic, we
517+
-- filter out cached ones *after* selecting the children of this
518+
-- node.
519+
races <- catMaybes
520+
<$> (readRaces >>= traverse (cachedST cacheRef) . take limit)
521+
let branching = length races
522+
-- tabulate "Races explored" (map show races) $
523+
tabulate "Branching factor" [bucket branching]
524+
. tabulate "Race reversals per schedule" [bucket (raceReversals control)]
525+
<$> conjoinParST
526+
[ --Debug.trace "New schedule:" $
527+
--Debug.trace (" "++show r) $
528+
--counterexample ("Schedule control: " ++ show r) $
529+
explore cacheRef n' ((m-1) `max` 1) r (Just trace0)
530+
| (r,n') <- zip races (divide (n-branching) branching) ]
531+
]
522532

523533
bucket :: Int -> String
524534
bucket n | n<10 = show n
@@ -537,39 +547,33 @@ exploreSimTrace optsf mainAction k =
537547
show tid ++ (case lab of Nothing -> ""
538548
Just l -> " ("++l++")")
539549

540-
-- cache of explored schedules
541-
cache :: IORef (Set ScheduleControl)
542-
cache = unsafePerformIO cacheIO
543-
544550
-- insert a schedule into the cache
545-
cached :: ScheduleControl -> Bool
546-
cached = unsafePerformIO . cachedIO
547-
548-
-- compute cache size; it's a function to make sure that `GHC` does not
549-
-- inline it (and share the same thunk).
550-
cacheSize :: () -> Int
551-
cacheSize = unsafePerformIO . cacheSizeIO
551+
cachedST :: STRef s (Set ScheduleControl) -> ScheduleControl -> ST s (Maybe ScheduleControl)
552+
cachedST cacheRef a = do
553+
set <- readSTRef cacheRef
554+
writeSTRef cacheRef (Set.insert a set)
555+
return $ if Set.member a set
556+
then Nothing
557+
else Just a
552558

553559
--
554-
-- Caching in IO monad
560+
-- Caching in ST monad
555561
--
556562

563+
-- TODO: Use STRef!
564+
557565
-- It is possible for the same control to be generated several times.
558566
-- To avoid exploring them twice, we keep a cache of explored schedules.
559-
cacheIO :: IO (IORef (Set ScheduleControl))
560-
cacheIO = newIORef $
567+
createCacheST :: ST s (STRef s (Set ScheduleControl))
568+
createCacheST = newSTRef $
561569
-- we use opts here just to be sure the reference cannot be
562570
-- lifted out of exploreSimTrace
563571
if explorationScheduleBound opts >=0
564572
then Set.empty
565573
else error "exploreSimTrace: negative schedule bound"
566574

567-
cachedIO :: ScheduleControl -> IO Bool
568-
cachedIO m = atomicModifyIORef' cache $ \set ->
569-
(Set.insert m set, Set.member m set)
570-
571-
cacheSizeIO :: () -> IO Int
572-
cacheSizeIO () = Set.size <$> readIORef cache
575+
cacheSizeST :: STRef s (Set ScheduleControl) -> ST s Int
576+
cacheSizeST = fmap Set.size . readSTRef
573577

574578

575579
-- | A specialised version of `controlSimTrace'.
@@ -587,8 +591,8 @@ replaySimTrace :: forall a test. (Testable test)
587591
-- will not contain any race events
588592
-> Property
589593
replaySimTrace opts mainAction control k =
590-
let (_,trace) = detachTraceRaces $
591-
controlSimTrace (explorationStepTimelimit opts) control mainAction
594+
let trace = runST $ fmap snd $ detachTraceRacesST =<<
595+
controlSimTraceST (explorationStepTimelimit opts) control mainAction
592596
in property (k trace)
593597

594598
-- | Run a simulation using a given schedule. This is useful to reproduce
@@ -623,38 +627,58 @@ raceReversals ControlFollow{} = error "Impossible: raceReversals ControlFoll
623627
-- this far, then we collect its identity only if it is reached using
624628
-- unsafePerformIO.
625629

626-
compareTraces :: Maybe (SimTrace a1)
627-
-> SimTrace a2
628-
-> (Maybe ((Time, ThreadId, Maybe ThreadLabel),
629-
Set.Set (ThreadId, Maybe ThreadLabel)),
630-
SimTrace a2)
631-
compareTraces Nothing trace = (Nothing, trace)
632-
compareTraces (Just passing) trace = unsafePerformIO $ do
633-
sleeper <- newIORef Nothing
634-
return (unsafePerformIO $ readIORef sleeper,
635-
go sleeper passing trace)
636-
where go sleeper (SimPORTrace tpass tidpass _ _ _ pass')
630+
compareTracesST :: forall a b s.
631+
Maybe (SimTrace a) -- ^ passing
632+
-> SimTrace b -- ^ failing
633+
-> ST s ( ST s (Maybe ( (Time, ThreadId, Maybe ThreadLabel)
634+
, Set.Set (ThreadId, Maybe ThreadLabel)
635+
))
636+
, SimTrace b
637+
)
638+
compareTracesST Nothing trace = return (return Nothing, trace)
639+
compareTracesST (Just passing) trace = do
640+
sleeper <- newSTRef Nothing
641+
trace' <- go sleeper passing trace
642+
return ( readSTRef sleeper
643+
, trace'
644+
)
645+
where
646+
go :: STRef s (Maybe ( (Time, ThreadId, Maybe ThreadLabel)
647+
, Set.Set (ThreadId, Maybe ThreadLabel)
648+
))
649+
-> SimTrace a -- ^ passing
650+
-> SimTrace b -- ^ failing
651+
-> ST s (SimTrace b)
652+
go sleeper (SimPORTrace tpass tidpass _ _ _ pass')
637653
(SimPORTrace tfail tidfail tstepfail tlfail evfail fail')
638654
| (tpass,tidpass) == (tfail,tidfail) =
639655
SimPORTrace tfail tidfail tstepfail tlfail evfail
640-
$ go sleeper pass' fail'
641-
go sleeper (SimPORTrace tpass tidpass tsteppass tlpass _ _) fail =
642-
unsafePerformIO $ do
643-
writeIORef sleeper $ Just ((tpass, tidpass, tlpass),Set.empty)
644-
return $ SimPORTrace tpass tidpass tsteppass tlpass EventThreadSleep
645-
$ wakeup sleeper tidpass fail
646-
go _ SimTrace {} _ = error "compareTraces: invariant violation"
647-
go _ _ SimTrace {} = error "compareTraces: invariant violation"
648-
go _ _ fail = fail
649-
656+
<$> go sleeper pass' fail'
657+
go sleeper (SimPORTrace tpass tidpass tsteppass tlpass _ _) fail = do
658+
writeSTRef sleeper $ Just ((tpass, tidpass, tlpass),Set.empty)
659+
SimPORTrace tpass tidpass tsteppass tlpass EventThreadSleep
660+
<$> wakeup sleeper tidpass fail
661+
go _ SimTrace {} _ = error "compareTracesST: invariant violation"
662+
go _ _ SimTrace {} = error "compareTracesST: invariant violation"
663+
go _ _ fail = return fail
664+
665+
wakeup :: STRef s (Maybe ( (Time, ThreadId, Maybe ThreadLabel)
666+
, Set.Set (ThreadId, Maybe ThreadLabel)
667+
))
668+
-> ThreadId
669+
-> SimTrace b
670+
-> ST s (SimTrace b)
650671
wakeup sleeper tidpass
651672
fail@(SimPORTrace tfail tidfail tstepfail tlfail evfail fail')
652673
| tidpass == tidfail =
653-
SimPORTrace tfail tidfail tstepfail tlfail EventThreadWake fail
654-
| otherwise = unsafePerformIO $ do
655-
Just (slp,racing) <- readIORef sleeper
656-
writeIORef sleeper $ Just (slp,Set.insert (tidfail,tlfail) racing)
657-
return $ SimPORTrace tfail tidfail tstepfail tlfail evfail
658-
$ wakeup sleeper tidpass fail'
659-
wakeup _ _ SimTrace {} = error "compareTraces: invariant violation"
660-
wakeup _ _ fail = fail
674+
return $ SimPORTrace tfail tidfail tstepfail tlfail EventThreadWake fail
675+
| otherwise = do
676+
ms <- readSTRef sleeper
677+
case ms of
678+
Just (slp, racing) -> do
679+
writeSTRef sleeper $ Just (slp,Set.insert (tidfail,tlfail) racing)
680+
SimPORTrace tfail tidfail tstepfail tlfail evfail
681+
<$> wakeup sleeper tidpass fail'
682+
Nothing -> error "compareTraceST: invariant violation"
683+
wakeup _ _ SimTrace {} = error "compareTracesST: invariant violation"
684+
wakeup _ _ fail = return fail

io-sim/src/Control/Monad/IOSimPOR/QuickCheckUtils.hs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,17 @@
55
module Control.Monad.IOSimPOR.QuickCheckUtils where
66

77
import Control.Parallel
8+
import Control.Monad.ST.Lazy
9+
import Control.Monad.ST.Lazy.Unsafe (unsafeInterleaveST)
810
import Test.QuickCheck.Gen
911
import Test.QuickCheck.Property
1012

13+
-- note: this only evaluates `prop` in parallel, not `ST` actions
14+
conjoinParST :: TestableNoCatch prop => [ST s prop] -> ST s Property
15+
conjoinParST sts = do
16+
ps <- sequence sts
17+
return $ conjoinPar ps
18+
1119
-- Take the conjunction of several properties, in parallel This is a
1220
-- modification of code from Test.QuickCheck.Property, to run non-IO
1321
-- properties in parallel. It also takes care NOT to label its result
@@ -30,6 +38,12 @@ conjoinPar = conjoinSpeculate speculate
3038
-- We also need a version of conjoin that is sequential, but does not
3139
-- label its result as an IO property unless one of its arguments
3240
-- is. Consequently it does not catch exceptions in its arguments.
41+
42+
conjoinNoCatchST :: TestableNoCatch prop => [ST s prop] -> ST s Property
43+
conjoinNoCatchST sts = do
44+
ps <- sequence sts
45+
return $ conjoinNoCatch ps
46+
3347
conjoinNoCatch :: TestableNoCatch prop => [prop] -> Property
3448
conjoinNoCatch = conjoinSpeculate id
3549

@@ -107,7 +121,7 @@ instance TestableNoCatch Result where
107121
propertyNoCatch = MkProperty . return . MkProp . return
108122

109123
instance TestableNoCatch Prop where
110-
propertyNoCatch p = MkProperty . return $ p
124+
propertyNoCatch = MkProperty . return
111125

112126
instance TestableNoCatch prop => TestableNoCatch (Gen prop) where
113127
propertyNoCatch mp = MkProperty $ do p <- mp; unProperty (againNoCatch $ propertyNoCatch p)

0 commit comments

Comments
 (0)