Skip to content

Commit 736aacb

Browse files
committed
MonadAsync: linking functions
Added: * `link2` * `link2Only` Fixes IntersectMBO/ouroboros-network#2650
1 parent 4c06f38 commit 736aacb

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ ones that come from `base`, `async`, or `excpetions` packages:
8787
* `Handler` (origin: `base`)
8888
* `MaskingState` (origin: `base`)
8989
* `Concurrently` (origin: `async`)
90-
* `ExceptionInLinkedThread` (origin: `async`)
90+
* `ExceptionInLinkedThread` (origin: `async`): `io-class`es version does not
91+
store `Async`
9192
* `ExitCase` (origin: `exceptions`)
9293

9394

io-classes/src/Control/Monad/Class/MonadAsync.hs

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE CPP #-}
12
{-# LANGUAGE DataKinds #-}
23
{-# LANGUAGE DefaultSignatures #-}
34
{-# LANGUAGE FlexibleContexts #-}
@@ -16,8 +17,10 @@ module Control.Monad.Class.MonadAsync
1617
, AsyncCancelled (..)
1718
, ExceptionInLinkedThread (..)
1819
, link
19-
, linkTo
2020
, linkOnly
21+
, link2
22+
, link2Only
23+
, linkTo
2124
, linkToOnly
2225
, mapConcurrently
2326
, forConcurrently
@@ -365,11 +368,12 @@ instance MonadAsync IO where
365368
-- We don't use the implementation of linking from 'Control.Concurrent.Async'
366369
-- directly because:
367370
--
368-
-- 1. We need a generalized form of linking that links an async to an arbitrary
369-
-- thread ('linkTo')
370-
-- 2. If we /did/ use the real implementation, then the mock implementation and
371+
-- 1. If we /did/ use the real implementation, then the mock implementation and
371372
-- the real implementation would not be able to throw the same exception,
372373
-- because the exception type used by the real implementation is
374+
-- 2. We need a generalized form of linking that links an async to an arbitrary
375+
-- thread ('linkTo'), which is exposed only if cabal flag `+non-standard` is
376+
-- used.
373377
--
374378
-- > data ExceptionInLinkedThread =
375379
-- > forall a . ExceptionInLinkedThread (Async a) SomeException
@@ -399,6 +403,35 @@ instance Exception ExceptionInLinkedThread where
399403
fromException = E.asyncExceptionFromException
400404
toException = E.asyncExceptionToException
401405

406+
link :: (MonadAsync m, MonadFork m, MonadMask m)
407+
=> Async m a -> m ()
408+
link = linkOnly (not . isCancel)
409+
410+
linkOnly :: forall m a. (MonadAsync m, MonadFork m, MonadMask m)
411+
=> (SomeException -> Bool) -> Async m a -> m ()
412+
linkOnly shouldThrow a = do
413+
me <- myThreadId
414+
linkToOnly me shouldThrow a
415+
416+
link2 :: (MonadAsync m, MonadFork m, MonadMask m)
417+
=> Async m a -> Async m b -> m ()
418+
link2 = link2Only (not . isCancel)
419+
420+
link2Only :: (MonadAsync m, MonadFork m, MonadMask m)
421+
=> (SomeException -> Bool) -> Async m a -> Async m b -> m ()
422+
link2Only shouldThrow left right =
423+
void $ forkRepeat ("link2Only " <> show (tl, tr)) $ do
424+
r <- waitEitherCatch left right
425+
case r of
426+
Left (Left e) | shouldThrow e ->
427+
throwTo tr (ExceptionInLinkedThread (show tl) e)
428+
Right (Left e) | shouldThrow e ->
429+
throwTo tl (ExceptionInLinkedThread (show tr) e)
430+
_ -> return ()
431+
where
432+
tl = asyncThreadId left
433+
tr = asyncThreadId right
434+
402435
-- | Generalization of 'link' that links an async to an arbitrary thread.
403436
linkTo :: (MonadAsync m, MonadFork m, MonadMask m)
404437
=> ThreadId m -> Async m a -> m ()
@@ -420,16 +453,6 @@ linkToOnly tid shouldThrow a = do
420453
exceptionInLinkedThread =
421454
ExceptionInLinkedThread (show linkedThreadId)
422455

423-
link :: (MonadAsync m, MonadFork m, MonadMask m)
424-
=> Async m a -> m ()
425-
link = linkOnly (not . isCancel)
426-
427-
linkOnly :: forall m a. (MonadAsync m, MonadFork m, MonadMask m)
428-
=> (SomeException -> Bool) -> Async m a -> m ()
429-
linkOnly shouldThrow a = do
430-
me <- myThreadId
431-
linkToOnly me shouldThrow a
432-
433456
isCancel :: SomeException -> Bool
434457
isCancel e
435458
| Just AsyncCancelled <- fromException e = True
@@ -457,6 +480,8 @@ newtype WrappedAsync r (m :: Type -> Type) a =
457480

458481
instance ( MonadAsync m
459482
, MonadCatch (STM m)
483+
, MonadFork m
484+
, MonadMask m
460485
) => MonadAsync (ReaderT r m) where
461486
type Async (ReaderT r m) = WrappedAsync r m
462487
asyncThreadId (WrappedAsync a) = asyncThreadId a

0 commit comments

Comments
 (0)