Skip to content

Commit a74b49a

Browse files
committed
Add MonadCatch and MonadMask instances to RouteResultT and DelayedIO
Fix #1829
1 parent bac8d0a commit a74b49a

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

servant-server/src/Servant/Server/Internal/DelayedIO.hs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
module Servant.Server.Internal.DelayedIO where
66

77
import Control.Monad.Base (MonadBase (..))
8-
import Control.Monad.Catch (MonadThrow (..))
8+
import Control.Monad.Catch (MonadThrow (..), MonadCatch(..), MonadMask)
99
import Control.Monad.Reader (MonadReader (..), ReaderT (..), runReaderT)
1010
import Control.Monad.Trans (MonadIO (..), MonadTrans (..))
1111
import Control.Monad.Trans.Control (MonadBaseControl (..))
@@ -34,6 +34,8 @@ newtype DelayedIO a = DelayedIO {runDelayedIO' :: ReaderT Request (ResourceT (Ro
3434
, MonadReader Request
3535
, MonadResource
3636
, MonadThrow
37+
, MonadCatch
38+
, MonadMask
3739
)
3840

3941
instance MonadBase IO DelayedIO where
@@ -53,6 +55,7 @@ instance MonadBaseControl IO DelayedIO where
5355
runInBase (runInternalState (runReaderT (runDelayedIO' x) req) s)
5456
restoreM = DelayedIO . lift . withInternalState . const . restoreM
5557

58+
5659
runDelayedIO :: DelayedIO a -> Request -> ResourceT IO (RouteResult a)
5760
runDelayedIO m req = transResourceT runRouteResultT $ runReaderT (runDelayedIO' m) req
5861

servant-server/src/Servant/Server/Internal/RouteResult.hs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
module Servant.Server.Internal.RouteResult where
88

9-
import Control.Monad (ap, liftM)
9+
import Control.Monad (ap)
1010
import Control.Monad.Base (MonadBase (..))
11-
import Control.Monad.Catch (MonadThrow (..))
11+
import Control.Monad.Catch (ExitCase (..), MonadCatch (..), MonadMask (..), MonadThrow (..))
1212
import Control.Monad.Trans (MonadIO (..), MonadTrans (..))
1313
import Control.Monad.Trans.Control
1414
( ComposeSt
@@ -72,8 +72,48 @@ instance MonadBaseControl b m => MonadBaseControl b (RouteResultT m) where
7272

7373
instance MonadTransControl RouteResultT where
7474
type StT RouteResultT a = RouteResult a
75-
liftWith f = RouteResultT $ liftM return $ f runRouteResultT
75+
liftWith f = RouteResultT (return <$> f runRouteResultT)
7676
restoreT = RouteResultT
7777

7878
instance MonadThrow m => MonadThrow (RouteResultT m) where
7979
throwM = lift . throwM
80+
81+
instance MonadCatch m => MonadCatch (RouteResultT m) where
82+
catch (RouteResultT m) f = RouteResultT $ catch m (runRouteResultT . f)
83+
84+
instance MonadMask m => MonadMask (RouteResultT m) where
85+
mask f = RouteResultT $ mask $ \u -> runRouteResultT $ f (q u)
86+
where
87+
q
88+
:: (m (RouteResult a) -> m (RouteResult a))
89+
-> RouteResultT m a
90+
-> RouteResultT m a
91+
q u (RouteResultT b) = RouteResultT (u b)
92+
uninterruptibleMask f = RouteResultT $ uninterruptibleMask $ \u -> runRouteResultT $ f (q u)
93+
where
94+
q
95+
:: (m (RouteResult a) -> m (RouteResult a))
96+
-> RouteResultT m a
97+
-> RouteResultT m a
98+
q u (RouteResultT b) = RouteResultT (u b)
99+
100+
generalBracket acquire release use = RouteResultT $ do
101+
(eb, ec) <-
102+
generalBracket
103+
(runRouteResultT acquire)
104+
( \resourceRoute exitCase -> case resourceRoute of
105+
Fail e -> pure $ Fail e -- nothing to release, acquire didn't succeed
106+
FailFatal e -> pure $ FailFatal e
107+
Route resource -> case exitCase of
108+
ExitCaseSuccess (Route b) -> runRouteResultT (release resource (ExitCaseSuccess b))
109+
ExitCaseException e -> runRouteResultT (release resource (ExitCaseException e))
110+
_ -> runRouteResultT (release resource ExitCaseAbort)
111+
)
112+
( \case
113+
Fail e -> pure $ Fail e -- nothing to release, acquire didn't succeed
114+
FailFatal e -> pure $ FailFatal e
115+
Route resource -> runRouteResultT (use resource)
116+
)
117+
-- The order in which we perform those two effects doesn't matter,
118+
-- since the error message is the same regardless.
119+
return ((,) <$> eb <*> ec)

0 commit comments

Comments
 (0)