Skip to content

Commit abab0be

Browse files
committed
io-sim: MonadFix instance using unsafeInterleaveIO
1 parent 540d50a commit abab0be

File tree

4 files changed

+189
-4
lines changed

4 files changed

+189
-4
lines changed

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,13 @@ import qualified Data.Set as Set
6666
import Data.Time (UTCTime (..), fromGregorian)
6767
import Data.Dynamic
6868

69-
import Control.Exception (assert)
69+
import Control.Exception (NonTermination (..),
70+
assert, throw)
7071
import Control.Monad (join)
7172

7273
import Control.Monad (when)
7374
import Control.Monad.ST.Lazy
74-
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST)
75+
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST)
7576
import Data.STRef.Lazy
7677

7778
import Control.Monad.Class.MonadSTM hiding (STM, TVar)
@@ -527,6 +528,14 @@ schedule thread@Thread{
527528
-- ExploreRaces is ignored by this simulator
528529
ExploreRaces k -> schedule thread{ threadControl = ThreadControl k ctl } simstate
529530

531+
Fix f k -> do
532+
r <- newSTRef (throw NonTermination)
533+
x <- unsafeInterleaveST $ readSTRef r
534+
let k' = unIOSim (f x) $ \x' ->
535+
LiftST (lazyToStrictST (writeSTRef r x')) (\() -> k x')
536+
thread' = thread { threadControl = ThreadControl k' ctl }
537+
schedule thread' simstate
538+
530539

531540
threadInterruptible :: Thread s a -> Bool
532541
threadInterruptible thread =

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ module Control.Monad.IOSim.Types
7070
import Control.Exception (ErrorCall (..), asyncExceptionFromException, asyncExceptionToException)
7171
import Control.Applicative
7272
import Control.Monad
73+
import Control.Monad.Fix (MonadFix (..))
7374

7475
import Control.Monad.Class.MonadAsync hiding (Async)
7576
import qualified Control.Monad.Class.MonadAsync as MonadAsync
@@ -163,6 +164,8 @@ data SimA s a where
163164

164165
ExploreRaces :: SimA s b -> SimA s b
165166

167+
Fix :: (x -> IOSim s x) -> (x -> SimA s r) -> SimA s r
168+
166169

167170
newtype STM s a = STM { unSTM :: forall r. (a -> StmA s r) -> StmA s r }
168171

@@ -238,6 +241,9 @@ instance Monoid a => Monoid (IOSim s a) where
238241
instance Fail.MonadFail (IOSim s) where
239242
fail msg = IOSim $ \_ -> Throw (toException (IO.Error.userError msg))
240243

244+
instance MonadFix (IOSim s) where
245+
mfix f = IOSim $ \k -> Fix f k
246+
241247

242248
instance Functor (STM s) where
243249
{-# INLINE fmap #-}

io-sim/src/Control/Monad/IOSimPOR/Internal.hs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,12 @@ import Data.Set (Set)
6969
import qualified Data.Set as Set
7070
import Data.Time (UTCTime (..), fromGregorian)
7171

72-
import Control.Exception (assert)
72+
import Control.Exception (NonTermination (..), assert, throw)
7373
import Control.Monad (join)
7474

7575
import Control.Monad (when)
7676
import Control.Monad.ST.Lazy
77-
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST)
77+
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST)
7878
import Data.STRef.Lazy
7979

8080
import Control.Monad.Class.MonadSTM hiding (STM, TVar)
@@ -608,6 +608,14 @@ schedule thread@Thread{
608608
, threadRacy = True }
609609
schedule thread' simstate
610610

611+
Fix f k -> do
612+
r <- newSTRef (throw NonTermination)
613+
x <- unsafeInterleaveST $ readSTRef r
614+
let k' = unIOSim (f x) $ \x' ->
615+
LiftST (lazyToStrictST (writeSTRef r x')) (\() -> k x')
616+
thread' = thread { threadControl = ThreadControl k' ctl }
617+
schedule thread' simstate
618+
611619
GetMaskState k -> do
612620
let thread' = thread { threadControl = ThreadControl (k maskst) ctl }
613621
schedule thread' simstate

io-sim/test/Test/IOSim.hs

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE CPP #-}
12
{-# LANGUAGE FlexibleContexts #-}
23
{-# LANGUAGE RankNTypes #-}
34
{-# LANGUAGE ScopedTypeVariables #-}
@@ -11,20 +12,25 @@ module Test.IOSim
1112
import Data.Array
1213
import Data.Either (isLeft)
1314
import Data.Fixed (Fixed (..), Micro)
15+
import Data.Foldable (foldl')
1416
import Data.Function (on)
17+
import Data.Functor (($>))
1518
import Data.Graph
1619
import Data.List (sortBy)
1720
import Data.Time.Clock (picosecondsToDiffTime)
1821

1922
import Control.Exception (ArithException (..))
2023
import Control.Monad
24+
import Control.Monad.Fix
2125
import System.IO.Error (ioeGetErrorString, isUserError)
2226

2327
import Control.Monad.Class.MonadFork
2428
import Control.Monad.Class.MonadSTM.Strict
2529
import Control.Monad.Class.MonadSay
30+
import qualified Control.Monad.Class.MonadSTM as LazySTM
2631
import Control.Monad.Class.MonadThrow
2732
import Control.Monad.Class.MonadTimer
33+
import Control.Monad.Class.MonadTime
2834
import Control.Monad.IOSim
2935

3036
import Test.STM
@@ -131,6 +137,13 @@ tests =
131137
[ testProperty "Reference vs IO" prop_stm_referenceIO
132138
, testProperty "Reference vs Sim" prop_stm_referenceSim
133139
]
140+
, testGroup "MonadFix instance"
141+
[ testProperty "purity" prop_mfix_purity
142+
, testProperty "purity2" prop_mfix_purity_2
143+
, testProperty "tightening" prop_mfix_left_shrinking
144+
, testProperty "lazy" prop_mfix_lazy
145+
, testProperty "recdata" prop_mfix_recdata
146+
]
134147
]
135148

136149

@@ -401,6 +414,155 @@ test_wakeup_order = do
401414
return (wakupOrder === [0..9]) --FIFO order
402415

403416

417+
--
418+
-- MonadFix properties
419+
--
420+
421+
-- | Purity demands that @mfix (return . f) = return (fix f)@.
422+
--
423+
prop_mfix_purity :: Positive Int -> Bool
424+
prop_mfix_purity (Positive n) =
425+
runSimOrThrow
426+
(mfix (return . factorial)) n
427+
== fix factorial n
428+
where
429+
factorial :: (Int -> Int) -> Int -> Int
430+
factorial = \rec_ k -> if k <= 1 then 1 else k * rec_ (k - 1)
431+
432+
433+
prop_mfix_purity_2 :: [Positive Int] -> Bool
434+
prop_mfix_purity_2 as =
435+
-- note: both 'IOSim' expressions are equivalent using 'Monad' and
436+
-- 'Applicative' laws only.
437+
runSimOrThrow (join $ mfix (return . recDelay)
438+
<*> return as')
439+
== expected
440+
&&
441+
runSimOrThrow (mfix (return . recDelay) >>= ($ as'))
442+
== expected
443+
where
444+
as' :: [Int]
445+
as' = getPositive `map` as
446+
447+
-- recursive sum using 'threadDelay'
448+
recDelay :: ( MonadMonotonicTime m
449+
, MonadDelay m
450+
)
451+
=> ([Int] -> m Time)
452+
-> [Int] -> m Time
453+
recDelay = \rec_ bs ->
454+
case bs of
455+
[] -> getMonotonicTime
456+
(b : bs') -> threadDelay (realToFrac b)
457+
>> rec_ bs'
458+
459+
expected :: Time
460+
expected = foldl' (flip addTime)
461+
(Time 0)
462+
(realToFrac `map` as')
463+
464+
465+
prop_mfix_left_shrinking
466+
:: Int
467+
-> NonNegative Int
468+
-> Positive Int
469+
-> Bool
470+
prop_mfix_left_shrinking n (NonNegative d) (Positive i) =
471+
let mn :: IOSim s Int
472+
mn = do say ""
473+
threadDelay (realToFrac d)
474+
return n
475+
in
476+
take i
477+
(runSimOrThrow $
478+
mfix (\rec_ -> mn >>= \a -> do
479+
threadDelay (realToFrac d) $> a : rec_))
480+
==
481+
take i
482+
(runSimOrThrow $
483+
mn >>= \a ->
484+
(mfix (\rec_ -> do
485+
threadDelay (realToFrac d) $> a : rec_)))
486+
487+
488+
489+
-- | 'Example 8.2.1' in 'Value Recursion in Monadic Computations'
490+
-- <https://leventerkok.github.io/papers/erkok-thesis.pdf>
491+
--
492+
prop_mfix_lazy :: NonEmptyList Char
493+
-> Bool
494+
prop_mfix_lazy (NonEmpty env) =
495+
take samples
496+
(runSimOrThrow (withEnv (mfix . replicateHeadM)))
497+
== replicate samples (head env)
498+
where
499+
samples :: Int
500+
samples = 10
501+
502+
replicateHeadM ::
503+
(
504+
#if MIN_VERSION_base(4,13,0)
505+
MonadFail m,
506+
MonadFail (STM m),
507+
#endif
508+
MonadSTM m
509+
)
510+
=> m Char
511+
-> [Char] -> m [Char]
512+
replicateHeadM getChar_ as = do
513+
-- Note: 'getChar' will be executed only once! This follows from 'fixIO`
514+
-- semantics.
515+
a <- getChar_
516+
return (a : as)
517+
518+
-- construct 'getChar' using the simulated environment
519+
withEnv :: (
520+
#if MIN_VERSION_base(4,13,0)
521+
MonadFail m,
522+
#endif
523+
MonadSTM m
524+
)
525+
=> (m Char -> m a) -> m a
526+
withEnv k = do
527+
v <- newTVarIO env
528+
let getChar_ =
529+
atomically $ do
530+
as <- readTVar v
531+
case as of
532+
[] -> error "withEnv: runtime error"
533+
(a : as') -> writeTVar v as'
534+
$> a
535+
k getChar_
536+
537+
538+
-- | 'Example 8.2.3' in 'Value Recursion in Monadic Computations'
539+
-- <https://leventerkok.github.io/papers/erkok-thesis.pdf>
540+
--
541+
prop_mfix_recdata :: Property
542+
prop_mfix_recdata = ioProperty $ do
543+
expected <- experiment
544+
let res = runSimOrThrow experiment
545+
return $
546+
take samples res
547+
==
548+
take samples expected
549+
where
550+
samples :: Int
551+
samples = 10
552+
553+
experiment :: ( MonadSTM m
554+
, MonadFix m
555+
)
556+
=> m [Int]
557+
experiment = do
558+
(_, y) <-
559+
mfix (\ ~(x, _) -> do
560+
y <- LazySTM.newTVarIO x
561+
return (1:x, y)
562+
)
563+
atomically (LazySTM.readTVar y)
564+
565+
404566
--
405567
-- Probe mini-abstraction
406568
--

0 commit comments

Comments
 (0)