Skip to content

Commit 4c06f38

Browse files
committed
MonadAsync: added various interfaces
* `asyncBound` * `asyncOn` * `asyncOnWithUnmask` * `withAsyncBound` * `withAsyncOn` * `withAsyncWithUnmask` * `withAsyncOnWithUnmask` * `compareAsyncs`
1 parent 486aeca commit 4c06f38

File tree

3 files changed

+92
-2
lines changed

3 files changed

+92
-2
lines changed

io-classes/src/Control/Monad/Class/MonadAsync.hs

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,19 @@ class ( MonadSTM m
5353
, MonadThread m
5454
) => MonadAsync m where
5555

56-
{-# MINIMAL async, asyncThreadId, cancel, cancelWith, asyncWithUnmask,
57-
waitCatchSTM, pollSTM #-}
56+
{-# MINIMAL async, asyncBound, asyncOn, asyncThreadId, cancel, cancelWith,
57+
asyncWithUnmask, asyncOnWithUnmask, waitCatchSTM, pollSTM #-}
5858

5959
-- | An asynchronous action
6060
type Async m = (async :: Type -> Type) | async -> m
6161

6262
async :: m a -> m (Async m a)
63+
asyncBound :: m a -> m (Async m a)
64+
asyncOn :: Int -> m a -> m (Async m a)
6365
asyncThreadId :: Async m a -> ThreadId m
6466
withAsync :: m a -> (Async m a -> m b) -> m b
67+
withAsyncBound :: m a -> (Async m a -> m b) -> m b
68+
withAsyncOn :: Int -> m a -> (Async m a -> m b) -> m b
6569

6670
waitSTM :: Async m a -> STM m a
6771
pollSTM :: Async m a -> STM m (Maybe (Either SomeException a))
@@ -144,9 +148,23 @@ class ( MonadSTM m
144148
concurrently_ :: m a -> m b -> m ()
145149

146150
asyncWithUnmask :: ((forall b . m b -> m b) -> m a) -> m (Async m a)
151+
asyncOnWithUnmask :: Int -> ((forall b . m b -> m b) -> m a) -> m (Async m a)
152+
withAsyncWithUnmask :: ((forall c. m c -> m c) -> m a) -> (Async m a -> m b) -> m b
153+
withAsyncOnWithUnmask :: Int -> ((forall c. m c -> m c) -> m a) -> (Async m a -> m b) -> m b
154+
155+
compareAsyncs :: Async m a -> Async m b -> Ordering
147156

148157
-- default implementations
149158
default withAsync :: MonadMask m => m a -> (Async m a -> m b) -> m b
159+
default withAsyncBound:: MonadMask m => m a -> (Async m a -> m b) -> m b
160+
default withAsyncOn :: MonadMask m => Int -> m a -> (Async m a -> m b) -> m b
161+
default withAsyncWithUnmask
162+
:: MonadMask m => ((forall c. m c -> m c) -> m a)
163+
-> (Async m a -> m b) -> m b
164+
default withAsyncOnWithUnmask
165+
:: MonadMask m => Int
166+
-> ((forall c. m c -> m c) -> m a)
167+
-> (Async m a -> m b) -> m b
150168
default uninterruptibleCancel
151169
:: MonadMask m => Async m a -> m ()
152170
default waitAnyCancel :: MonadThrow m => [Async m a] -> m (Async m a, a)
@@ -157,12 +175,35 @@ class ( MonadSTM m
157175
default waitEitherCatchCancel :: MonadThrow m => Async m a -> Async m b
158176
-> m (Either (Either SomeException a)
159177
(Either SomeException b))
178+
default compareAsyncs :: Ord (ThreadId m)
179+
=> Async m a -> Async m b -> Ordering
160180

161181
withAsync action inner = mask $ \restore -> do
162182
a <- async (restore action)
163183
restore (inner a)
164184
`finally` uninterruptibleCancel a
165185

186+
withAsyncBound action inner = mask $ \restore -> do
187+
a <- asyncBound (restore action)
188+
restore (inner a)
189+
`finally` uninterruptibleCancel a
190+
191+
withAsyncOn n action inner = mask $ \restore -> do
192+
a <- asyncOn n (restore action)
193+
restore (inner a)
194+
`finally` uninterruptibleCancel a
195+
196+
197+
withAsyncWithUnmask action inner = mask $ \restore -> do
198+
a <- asyncWithUnmask action
199+
restore (inner a)
200+
`finally` uninterruptibleCancel a
201+
202+
withAsyncOnWithUnmask n action inner = mask $ \restore -> do
203+
a <- asyncOnWithUnmask n action
204+
restore (inner a)
205+
`finally` uninterruptibleCancel a
206+
166207
wait = atomically . waitSTM
167208
poll = atomically . pollSTM
168209
waitCatch = atomically . waitCatchSTM
@@ -202,6 +243,8 @@ class ( MonadSTM m
202243

203244
concurrently_ left right = void $ concurrently left right
204245

246+
compareAsyncs a b = asyncThreadId a `compare` asyncThreadId b
247+
205248
-- | Similar to 'Async.Concurrently' but which works for any 'MonadAsync'
206249
-- instance.
207250
--
@@ -265,8 +308,12 @@ instance MonadAsync IO where
265308
type Async IO = Async.Async
266309

267310
async = Async.async
311+
asyncBound = Async.asyncBound
312+
asyncOn = Async.asyncOn
268313
asyncThreadId = Async.asyncThreadId
269314
withAsync = Async.withAsync
315+
withAsyncBound = Async.withAsyncBound
316+
withAsyncOn = Async.withAsyncOn
270317

271318
waitSTM = Async.waitSTM
272319
pollSTM = Async.pollSTM
@@ -303,6 +350,11 @@ instance MonadAsync IO where
303350
concurrently_ = Async.concurrently_
304351

305352
asyncWithUnmask = Async.asyncWithUnmask
353+
asyncOnWithUnmask = Async.asyncOnWithUnmask
354+
withAsyncWithUnmask = Async.withAsyncWithUnmask
355+
withAsyncOnWithUnmask = Async.withAsyncOnWithUnmask
356+
357+
compareAsyncs = Async.compareAsyncs
306358

307359

308360
--
@@ -410,15 +462,45 @@ instance ( MonadAsync m
410462
asyncThreadId (WrappedAsync a) = asyncThreadId a
411463

412464
async (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> async (ma r)
465+
asyncBound (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> asyncBound (ma r)
466+
asyncOn n (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> asyncOn n (ma r)
413467
withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r)
414468
$ \a -> runReaderT (f (WrappedAsync a)) r
469+
withAsyncBound (ReaderT ma) f = ReaderT $ \r -> withAsyncBound (ma r)
470+
$ \a -> runReaderT (f (WrappedAsync a)) r
471+
withAsyncOn n (ReaderT ma) f = ReaderT $ \r -> withAsyncOn n (ma r)
472+
$ \a -> runReaderT (f (WrappedAsync a)) r
473+
415474
asyncWithUnmask f = ReaderT $ \r -> fmap WrappedAsync
416475
$ asyncWithUnmask
417476
$ \unmask -> runReaderT (f (liftF unmask)) r
418477
where
419478
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
420479
liftF g (ReaderT r) = ReaderT (g . r)
421480

481+
asyncOnWithUnmask n f = ReaderT $ \r -> fmap WrappedAsync
482+
$ asyncOnWithUnmask n
483+
$ \unmask -> runReaderT (f (liftF unmask)) r
484+
where
485+
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
486+
liftF g (ReaderT r) = ReaderT (g . r)
487+
488+
withAsyncWithUnmask action f =
489+
ReaderT $ \r -> withAsyncWithUnmask (\unmask -> case action (liftF unmask) of
490+
ReaderT ma -> ma r)
491+
$ \a -> runReaderT (f (WrappedAsync a)) r
492+
where
493+
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
494+
liftF g (ReaderT r) = ReaderT (g . r)
495+
496+
withAsyncOnWithUnmask n action f =
497+
ReaderT $ \r -> withAsyncOnWithUnmask n (\unmask -> case action (liftF unmask) of
498+
ReaderT ma -> ma r)
499+
$ \a -> runReaderT (f (WrappedAsync a)) r
500+
where
501+
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
502+
liftF g (ReaderT r) = ReaderT (g . r)
503+
422504
waitCatchSTM = WrappedSTM . waitCatchSTM . unWrapAsync
423505
pollSTM = WrappedSTM . pollSTM . unWrapAsync
424506

io-classes/src/Control/Monad/Class/MonadFork.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class (Monad m, Eq (ThreadId m),
3232
class MonadThread m => MonadFork m where
3333

3434
forkIO :: m () -> m (ThreadId m)
35+
forkOn :: Int -> m () -> m (ThreadId m)
3536
forkIOWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)
3637
throwTo :: Exception e => ThreadId m -> e -> m ()
3738

@@ -56,6 +57,7 @@ instance MonadThread IO where
5657

5758
instance MonadFork IO where
5859
forkIO = IO.forkIO
60+
forkOn = IO.forkOn
5961
forkIOWithUnmask = IO.forkIOWithUnmask
6062
throwTo = IO.throwTo
6163
killThread = IO.killThread
@@ -68,6 +70,7 @@ instance MonadThread m => MonadThread (ReaderT r m) where
6870

6971
instance MonadFork m => MonadFork (ReaderT e m) where
7072
forkIO (ReaderT f) = ReaderT $ \e -> forkIO (f e)
73+
forkOn n (ReaderT f) = ReaderT $ \e -> forkOn n (f e)
7174
forkIOWithUnmask k = ReaderT $ \e -> forkIOWithUnmask $ \restore ->
7275
let restore' :: ReaderT e m a -> ReaderT e m a
7376
restore' (ReaderT f) = ReaderT $ restore . f

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ instance MonadThread (IOSim s) where
369369

370370
instance MonadFork (IOSim s) where
371371
forkIO task = IOSim $ oneShot $ \k -> Fork task k
372+
forkOn _ task = IOSim $ oneShot $ \k -> Fork task k
372373
forkIOWithUnmask f = forkIO (f unblock)
373374
throwTo tid e = IOSim $ oneShot $ \k -> ThrowTo (toException e) tid (k ())
374375
yield = IOSim $ oneShot $ \k -> YieldSim (k ())
@@ -470,6 +471,9 @@ instance MonadAsync (IOSim s) where
470471
MonadSTM.labelTMVarIO var ("async-" ++ show tid)
471472
return (Async tid (MonadSTM.readTMVar var))
472473

474+
asyncOn _ = async
475+
asyncBound = async
476+
473477
asyncThreadId (Async tid _) = tid
474478

475479
waitCatchSTM (Async _ w) = w
@@ -479,6 +483,7 @@ instance MonadAsync (IOSim s) where
479483
cancelWith a@(Async tid _) e = throwTo tid e <* waitCatch a
480484

481485
asyncWithUnmask k = async (k unblock)
486+
asyncOnWithUnmask _ k = async (k unblock)
482487

483488
instance MonadST (IOSim s) where
484489
withLiftST f = f liftST

0 commit comments

Comments
 (0)