Skip to content

Commit 0c12aea

Browse files
iohk-bors[bot]coot
andauthored
Merge #3585
3585: Monad transformer instances for various type classes in `io-classes` r=coot a=coot Co-authored-by: Marcin Szamotulski <[email protected]>
2 parents 9acc5b7 + fc92450 commit 0c12aea

File tree

4 files changed

+538
-11
lines changed

4 files changed

+538
-11
lines changed

io-classes/src/Control/Monad/Class/MonadAsync.hs

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
{-# LANGUAGE DataKinds #-}
12
{-# LANGUAGE DefaultSignatures #-}
23
{-# LANGUAGE FlexibleContexts #-}
4+
{-# LANGUAGE GADTs #-}
35
{-# LANGUAGE MultiParamTypeClasses #-}
46
{-# LANGUAGE QuantifiedConstraints #-}
57
{-# LANGUAGE RankNTypes #-}
68
{-# LANGUAGE ScopedTypeVariables #-}
79
{-# LANGUAGE TypeApplications #-}
810
{-# LANGUAGE TypeFamilies #-}
911
{-# LANGUAGE TypeFamilyDependencies #-}
10-
12+
-- MonadAsync's ReaderT instance is undecidable.
13+
{-# LANGUAGE UndecidableInstances #-}
1114
module Control.Monad.Class.MonadAsync
1215
( MonadAsync (..)
1316
, AsyncCancelled (..)
@@ -34,10 +37,14 @@ import Control.Monad.Class.MonadSTM
3437
import Control.Monad.Class.MonadThrow
3538
import Control.Monad.Class.MonadTimer
3639

40+
import Control.Monad.Trans (lift)
41+
import Control.Monad.Reader (ReaderT (..))
42+
3743
import Control.Concurrent.Async (AsyncCancelled (..))
3844
import qualified Control.Concurrent.Async as Async
3945
import qualified Control.Exception as E
4046

47+
import Data.Bifunctor (first)
4148
import Data.Foldable (fold)
4249
import Data.Functor (void)
4350
import Data.Kind (Type)
@@ -390,3 +397,77 @@ forkRepeat label action =
390397

391398
tryAll :: MonadCatch m => m a -> m (Either SomeException a)
392399
tryAll = try
400+
401+
402+
--
403+
-- ReaderT instance
404+
--
405+
406+
newtype WrappedAsync r (m :: Type -> Type) a =
407+
WrappedAsync { unWrapAsync :: Async m a }
408+
409+
instance ( MonadAsync m
410+
, MonadCatch (STM m)
411+
) => MonadAsync (ReaderT r m) where
412+
type Async (ReaderT r m) = WrappedAsync r m
413+
asyncThreadId (WrappedAsync a) = asyncThreadId a
414+
415+
async (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> async (ma r)
416+
withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r)
417+
$ \a -> runReaderT (f (WrappedAsync a)) r
418+
asyncWithUnmask f = ReaderT $ \r -> fmap WrappedAsync
419+
$ asyncWithUnmask
420+
$ \unmask -> runReaderT (f (liftF unmask)) r
421+
where
422+
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
423+
liftF g (ReaderT r) = ReaderT (g . r)
424+
425+
waitCatchSTM = WrappedSTM . waitCatchSTM . unWrapAsync
426+
pollSTM = WrappedSTM . pollSTM . unWrapAsync
427+
428+
race (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race (ma r) (mb r)
429+
race_ (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race_ (ma r) (mb r)
430+
concurrently (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently (ma r) (mb r)
431+
432+
wait = lift . wait . unWrapAsync
433+
poll = lift . poll . unWrapAsync
434+
waitCatch = lift . waitCatch . unWrapAsync
435+
cancel = lift . cancel . unWrapAsync
436+
uninterruptibleCancel = lift . uninterruptibleCancel
437+
. unWrapAsync
438+
cancelWith = (lift .: cancelWith)
439+
. unWrapAsync
440+
waitAny = fmap (first WrappedAsync)
441+
. lift . waitAny
442+
. map unWrapAsync
443+
waitAnyCatch = fmap (first WrappedAsync)
444+
. lift . waitAnyCatch
445+
. map unWrapAsync
446+
waitAnyCancel = fmap (first WrappedAsync)
447+
. lift . waitAnyCancel
448+
. map unWrapAsync
449+
waitAnyCatchCancel = fmap (first WrappedAsync)
450+
. lift . waitAnyCatchCancel
451+
. map unWrapAsync
452+
waitEither = on (lift .: waitEither) unWrapAsync
453+
waitEitherCatch = on (lift .: waitEitherCatch) unWrapAsync
454+
waitEitherCancel = on (lift .: waitEitherCancel) unWrapAsync
455+
waitEitherCatchCancel = on (lift .: waitEitherCatchCancel) unWrapAsync
456+
waitEither_ = on (lift .: waitEither_) unWrapAsync
457+
waitBoth = on (lift .: waitBoth) unWrapAsync
458+
459+
460+
--
461+
-- Utilities
462+
--
463+
464+
(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
465+
(f .: g) x y = f (g x y)
466+
467+
468+
-- | A higher order version of 'Data.Function.on'
469+
--
470+
on :: (f a -> f b -> c)
471+
-> (forall x. g x -> f x)
472+
-> (g a -> g b -> c)
473+
on f g = \a b -> f (g a) (g b)

io-classes/src/Control/Monad/Class/MonadFork.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ module Control.Monad.Class.MonadFork
1414

1515
import qualified Control.Concurrent as IO
1616
import Control.Exception (AsyncException (ThreadKilled), Exception)
17-
import Control.Monad.Reader
17+
import Control.Monad.Reader (ReaderT (..), lift)
1818
import Data.Kind (Type)
1919
import qualified GHC.Conc.Sync as IO (labelThread)
2020

@@ -58,8 +58,8 @@ instance MonadFork IO where
5858
throwTo = IO.throwTo
5959
killThread = IO.killThread
6060

61-
instance MonadThread m => MonadThread (ReaderT e m) where
62-
type ThreadId (ReaderT e m) = ThreadId m
61+
instance MonadThread m => MonadThread (ReaderT r m) where
62+
type ThreadId (ReaderT r m) = ThreadId m
6363
myThreadId = lift myThreadId
6464
labelThread t l = lift (labelThread t l)
6565

0 commit comments

Comments
 (0)