Skip to content

Commit 0a587e2

Browse files
committed
io-classes: MonadAsync instance for ReaderT
1 parent 80eca9e commit 0a587e2

File tree

1 file changed

+82
-1
lines changed

1 file changed

+82
-1
lines changed

io-classes/src/Control/Monad/Class/MonadAsync.hs

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
{-# LANGUAGE DataKinds #-}
12
{-# LANGUAGE DefaultSignatures #-}
23
{-# LANGUAGE FlexibleContexts #-}
4+
{-# LANGUAGE GADTs #-}
35
{-# LANGUAGE MultiParamTypeClasses #-}
46
{-# LANGUAGE QuantifiedConstraints #-}
57
{-# LANGUAGE RankNTypes #-}
68
{-# LANGUAGE ScopedTypeVariables #-}
79
{-# LANGUAGE TypeApplications #-}
810
{-# LANGUAGE TypeFamilies #-}
911
{-# LANGUAGE TypeFamilyDependencies #-}
10-
12+
-- MonadAsync's ReaderT instance is undecidable.
13+
{-# LANGUAGE UndecidableInstances #-}
1114
module Control.Monad.Class.MonadAsync
1215
( MonadAsync (..)
1316
, AsyncCancelled (..)
@@ -34,10 +37,14 @@ import Control.Monad.Class.MonadSTM
3437
import Control.Monad.Class.MonadThrow
3538
import Control.Monad.Class.MonadTimer
3639

40+
import Control.Monad.Trans (lift)
41+
import Control.Monad.Reader (ReaderT (..))
42+
3743
import Control.Concurrent.Async (AsyncCancelled (..))
3844
import qualified Control.Concurrent.Async as Async
3945
import qualified Control.Exception as E
4046

47+
import Data.Bifunctor (first)
4148
import Data.Foldable (fold)
4249
import Data.Functor (void)
4350
import Data.Kind (Type)
@@ -390,3 +397,77 @@ forkRepeat label action =
390397

391398
tryAll :: MonadCatch m => m a -> m (Either SomeException a)
392399
tryAll = try
400+
401+
402+
--
403+
-- ReaderT instance
404+
--
405+
406+
newtype WrappedAsync r (m :: Type -> Type) a =
407+
WrappedAsync { unWrapAsync :: Async m a }
408+
409+
instance ( MonadAsync m
410+
, MonadCatch (STM m)
411+
) => MonadAsync (ReaderT r m) where
412+
type Async (ReaderT r m) = WrappedAsync r m
413+
asyncThreadId (WrappedAsync a) = asyncThreadId a
414+
415+
async (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> async (ma r)
416+
withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r)
417+
$ \a -> runReaderT (f (WrappedAsync a)) r
418+
asyncWithUnmask f = ReaderT $ \r -> fmap WrappedAsync
419+
$ asyncWithUnmask
420+
$ \unmask -> runReaderT (f (liftF unmask)) r
421+
where
422+
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
423+
liftF g (ReaderT r) = ReaderT (g . r)
424+
425+
waitCatchSTM = WrappedSTM . waitCatchSTM . unWrapAsync
426+
pollSTM = WrappedSTM . pollSTM . unWrapAsync
427+
428+
race (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race (ma r) (mb r)
429+
race_ (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race_ (ma r) (mb r)
430+
concurrently (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently (ma r) (mb r)
431+
432+
wait = lift . wait . unWrapAsync
433+
poll = lift . poll . unWrapAsync
434+
waitCatch = lift . waitCatch . unWrapAsync
435+
cancel = lift . cancel . unWrapAsync
436+
uninterruptibleCancel = lift . uninterruptibleCancel
437+
. unWrapAsync
438+
cancelWith = (lift .: cancelWith)
439+
. unWrapAsync
440+
waitAny = fmap (first WrappedAsync)
441+
. lift . waitAny
442+
. map unWrapAsync
443+
waitAnyCatch = fmap (first WrappedAsync)
444+
. lift . waitAnyCatch
445+
. map unWrapAsync
446+
waitAnyCancel = fmap (first WrappedAsync)
447+
. lift . waitAnyCancel
448+
. map unWrapAsync
449+
waitAnyCatchCancel = fmap (first WrappedAsync)
450+
. lift . waitAnyCatchCancel
451+
. map unWrapAsync
452+
waitEither = on (lift .: waitEither) unWrapAsync
453+
waitEitherCatch = on (lift .: waitEitherCatch) unWrapAsync
454+
waitEitherCancel = on (lift .: waitEitherCancel) unWrapAsync
455+
waitEitherCatchCancel = on (lift .: waitEitherCatchCancel) unWrapAsync
456+
waitEither_ = on (lift .: waitEither_) unWrapAsync
457+
waitBoth = on (lift .: waitBoth) unWrapAsync
458+
459+
460+
--
461+
-- Utilities
462+
--
463+
464+
(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
465+
(f .: g) x y = f (g x y)
466+
467+
468+
-- | A higher order version of 'Data.Function.on'
469+
--
470+
on :: (f a -> f b -> c)
471+
-> (forall x. g x -> f x)
472+
-> (g a -> g b -> c)
473+
on f g = \a b -> f (g a) (g b)

0 commit comments

Comments
 (0)