diff --git a/changelog.d/issue-1829 b/changelog.d/issue-1829 new file mode 100644 index 000000000..28d0f80cf --- /dev/null +++ b/changelog.d/issue-1829 @@ -0,0 +1,4 @@ +synopsis: Add MonadCatch and MonadMask instances to RouteResultT and DelayedIO +packages: servant-server +prs: #1830 +issues: #1829 diff --git a/servant-server/src/Servant/Server/Internal/DelayedIO.hs b/servant-server/src/Servant/Server/Internal/DelayedIO.hs index 933194909..5eba44c46 100644 --- a/servant-server/src/Servant/Server/Internal/DelayedIO.hs +++ b/servant-server/src/Servant/Server/Internal/DelayedIO.hs @@ -5,7 +5,7 @@ module Servant.Server.Internal.DelayedIO where import Control.Monad.Base (MonadBase (..)) -import Control.Monad.Catch (MonadThrow (..)) +import Control.Monad.Catch (MonadCatch (..), MonadMask, MonadThrow (..)) import Control.Monad.Reader (MonadReader (..), ReaderT (..), runReaderT) import Control.Monad.Trans (MonadIO (..), MonadTrans (..)) import Control.Monad.Trans.Control (MonadBaseControl (..)) @@ -30,7 +30,9 @@ newtype DelayedIO a = DelayedIO {runDelayedIO' :: ReaderT Request (ResourceT (Ro ( Applicative , Functor , Monad + , MonadCatch , MonadIO + , MonadMask , MonadReader Request , MonadResource , MonadThrow diff --git a/servant-server/src/Servant/Server/Internal/RouteResult.hs b/servant-server/src/Servant/Server/Internal/RouteResult.hs index cbf55a394..5427be129 100644 --- a/servant-server/src/Servant/Server/Internal/RouteResult.hs +++ b/servant-server/src/Servant/Server/Internal/RouteResult.hs @@ -6,9 +6,9 @@ module Servant.Server.Internal.RouteResult where -import Control.Monad (ap, liftM) +import Control.Monad (ap) import Control.Monad.Base (MonadBase (..)) -import Control.Monad.Catch (MonadThrow (..)) +import Control.Monad.Catch (ExitCase (..), MonadCatch (..), MonadMask (..), MonadThrow (..)) import Control.Monad.Trans (MonadIO (..), MonadTrans (..)) import Control.Monad.Trans.Control ( ComposeSt @@ -72,8 +72,48 @@ instance MonadBaseControl b m => MonadBaseControl b (RouteResultT m) where instance MonadTransControl RouteResultT where type StT RouteResultT a = RouteResult a - liftWith f = RouteResultT $ liftM return $ f runRouteResultT + liftWith f = RouteResultT (return <$> f runRouteResultT) restoreT = RouteResultT instance MonadThrow m => MonadThrow (RouteResultT m) where throwM = lift . throwM + +instance MonadCatch m => MonadCatch (RouteResultT m) where + catch (RouteResultT m) f = RouteResultT $ catch m (runRouteResultT . f) + +instance MonadMask m => MonadMask (RouteResultT m) where + mask f = RouteResultT $ mask $ \u -> runRouteResultT $ f (q u) + where + q + :: (m (RouteResult a) -> m (RouteResult a)) + -> RouteResultT m a + -> RouteResultT m a + q u (RouteResultT b) = RouteResultT (u b) + uninterruptibleMask f = RouteResultT $ uninterruptibleMask $ \u -> runRouteResultT $ f (q u) + where + q + :: (m (RouteResult a) -> m (RouteResult a)) + -> RouteResultT m a + -> RouteResultT m a + q u (RouteResultT b) = RouteResultT (u b) + + generalBracket acquire release use = RouteResultT $ do + (eb, ec) <- + generalBracket + (runRouteResultT acquire) + ( \resourceRoute exitCase -> case resourceRoute of + Fail e -> pure $ Fail e -- nothing to release, acquire didn't succeed + FailFatal e -> pure $ FailFatal e + Route resource -> case exitCase of + ExitCaseSuccess (Route b) -> runRouteResultT (release resource (ExitCaseSuccess b)) + ExitCaseException e -> runRouteResultT (release resource (ExitCaseException e)) + _ -> runRouteResultT (release resource ExitCaseAbort) + ) + ( \case + Fail e -> pure $ Fail e -- nothing to release, acquire didn't succeed + FailFatal e -> pure $ FailFatal e + Route resource -> runRouteResultT (use resource) + ) + -- The order in which we perform those two effects doesn't matter, + -- since the error message is the same regardless. + return ((,) <$> eb <*> ec)