Skip to content

Commit 7ac94bb

Browse files
Merge #887: Type.Infer: instance MonadTrans Solver: fx MonadPlus code
2 parents 4796681 + ae61038 commit 7ac94bb

File tree

1 file changed

+125
-112
lines changed

1 file changed

+125
-112
lines changed

src/Nix/Type/Infer.hs

Lines changed: 125 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE MultiWayIf #-}
12
{-# LANGUAGE CPP #-}
23
{-# LANGUAGE AllowAmbiguousTypes #-}
34
{-# LANGUAGE ConstraintKinds #-}
@@ -27,18 +28,36 @@ module Nix.Type.Infer
2728
)
2829
where
2930

30-
import Control.Applicative
31-
import Control.Arrow
32-
import Control.Monad.Catch
33-
import Control.Monad.Except
31+
import Control.Applicative ( Alternative
32+
, empty
33+
)
34+
import Data.Bifunctor ( Bifunctor(first) )
35+
import Control.Monad.Catch ( Exception(fromException, toException)
36+
, MonadThrow(..)
37+
, MonadCatch(..)
38+
)
39+
import Control.Monad.Except ( ExceptT
40+
, MonadError(..), runExceptT
41+
)
3442
#if !MIN_VERSION_base(4,13,0)
43+
import Prelude hiding ( fail )
3544
import Control.Monad.Fail
3645
#endif
37-
import Control.Monad.Logic
38-
import Control.Monad.Reader
46+
import Control.Monad.Logic hiding ( fail )
47+
import Control.Monad.Reader ( MonadReader(local)
48+
, ReaderT(..)
49+
, MonadFix
50+
)
3951
import Control.Monad.Ref
40-
import Control.Monad.ST
41-
import Control.Monad.State.Strict
52+
import Control.Monad.ST ( ST
53+
, runST
54+
)
55+
import Control.Monad.State.Strict ( modify
56+
, evalState
57+
, evalStateT
58+
, MonadState(put, get)
59+
, StateT(runStateT)
60+
)
4261
import Data.Fix ( foldFix )
4362
import Data.Foldable ( foldl'
4463
, foldrM
@@ -65,7 +84,7 @@ import Nix.Fresh
6584
import Nix.String
6685
import Nix.Scope
6786
import qualified Nix.Type.Assumption as As
68-
import Nix.Type.Env
87+
import Nix.Type.Env hiding ( empty )
6988
import qualified Nix.Type.Env as Env
7089
import Nix.Type.Type
7190
import Nix.Utils
@@ -79,8 +98,10 @@ import Nix.Var
7998
newtype InferT s m a =
8099
InferT
81100
{ getInfer ::
82-
ReaderT (Set.Set TVar, Scopes (InferT s m) (Judgment s))
83-
(StateT InferState (ExceptT InferError m)) a
101+
ReaderT
102+
(Set.Set TVar, Scopes (InferT s m) (Judgment s))
103+
(StateT InferState (ExceptT InferError m))
104+
a
84105
}
85106
deriving
86107
( Functor
@@ -236,25 +257,28 @@ inferType env ex = do
236257
Set.fromList (As.keys as) `Set.difference` Set.fromList (Env.keys env)
237258
unless (Set.null unbounds) $ typeError $ UnboundVariables
238259
(nub (Set.toList unbounds))
239-
let cs' =
240-
[ ExpInstConst t s
241-
| (x, ss) <- Env.toList env
242-
, s <- ss
243-
, t <- As.lookup x as
244-
]
260+
let
261+
cs' =
262+
[ ExpInstConst t s
263+
| (x, ss) <- Env.toList env
264+
, s <- ss
265+
, t <- As.lookup x as
266+
]
245267
inferState <- get
246-
let eres = (`evalState` inferState) $ runSolver $ do
268+
let
269+
eres = (`evalState` inferState) $ runSolver $
270+
do
247271
subst <- solve (cs <> cs')
248272
pure (subst, subst `apply` t)
249-
case eres of
250-
Left errs -> throwError $ TypeInferenceErrors errs
251-
Right xs -> pure xs
273+
either
274+
(throwError . TypeInferenceErrors)
275+
pure
276+
eres
252277

253278
-- | Solve for the toplevel type of an expression in a given environment
254279
inferExpr :: Env -> NExpr -> Either InferError [Scheme]
255-
inferExpr env ex = case runInfer (inferType env ex) of
256-
Left err -> Left err
257-
Right xs -> Right $ fmap (\(subst, ty) -> closeOver (subst `apply` ty)) xs
280+
inferExpr env ex =
281+
(fmap . fmap) (\(subst, ty) -> closeOver (subst `apply` ty)) $ runInfer $ inferType env ex
258282

259283
-- | Canonicalize and return the polymorphic toplevel type.
260284
closeOver :: Type -> Scheme
@@ -267,105 +291,92 @@ letters :: [String]
267291
letters = [1 ..] >>= flip replicateM ['a' .. 'z']
268292

269293
freshTVar :: MonadState InferState m => m TVar
270-
freshTVar = do
271-
s <- get
272-
put s { count = count s + 1 }
273-
pure $ TV (letters !! count s)
294+
freshTVar =
295+
do
296+
s <- get
297+
put s { count = count s + 1 }
298+
pure $ TV (letters !! count s)
274299

275300
fresh :: MonadState InferState m => m Type
276301
fresh = TVar <$> freshTVar
277302

278303
instantiate :: MonadState InferState m => Scheme -> m Type
279-
instantiate (Forall as t) = do
280-
as' <- traverse (const fresh) as
281-
let s = Subst $ Map.fromList $ zip as as'
282-
pure $ apply s t
304+
instantiate (Forall as t) =
305+
do
306+
as' <- traverse (const fresh) as
307+
let s = Subst $ Map.fromList $ zip as as'
308+
pure $ apply s t
283309

284310
generalize :: Set.Set TVar -> Type -> Scheme
285311
generalize free t = Forall as t
286-
where as = Set.toList $ ftv t `Set.difference` free
312+
where
313+
as = Set.toList $ ftv t `Set.difference` free
287314

288315
unops :: Type -> NUnaryOp -> [Constraint]
289-
unops u1 = \case
290-
NNot -> [EqConst u1 (typeFun [typeBool, typeBool])]
291-
NNeg ->
292-
[ EqConst
293-
u1
294-
(TMany [typeFun [typeInt, typeInt], typeFun [typeFloat, typeFloat]])
295-
]
316+
unops u1 op =
317+
[ EqConst u1
318+
(case op of
319+
NNot -> typeFun [typeBool , typeBool ]
320+
NNeg -> TMany [typeFun [typeInt, typeInt], typeFun [typeFloat, typeFloat]]
321+
)
322+
]
296323

297324
binops :: Type -> NBinaryOp -> [Constraint]
298-
binops u1 = \case
299-
NApp -> mempty -- this is handled separately
300-
301-
-- Equality tells you nothing about the types, because any two types are
302-
-- allowed.
303-
NEq -> mempty
304-
NNEq -> mempty
305-
306-
NGt -> inequality
307-
NGte -> inequality
308-
NLt -> inequality
309-
NLte -> inequality
310-
311-
NAnd -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
312-
NOr -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
313-
NImpl -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
314-
315-
NConcat -> [EqConst u1 (typeFun [typeList, typeList, typeList])]
316-
317-
NUpdate ->
318-
[ EqConst
319-
u1
320-
(TMany
321-
[ typeFun [typeSet, typeSet, typeSet]
322-
, typeFun [typeSet, typeNull, typeSet]
323-
, typeFun [typeNull, typeSet, typeSet]
324-
]
325-
)
326-
]
327-
328-
NPlus ->
329-
[ EqConst
330-
u1
331-
(TMany
332-
[ typeFun [typeInt, typeInt, typeInt]
333-
, typeFun [typeFloat, typeFloat, typeFloat]
334-
, typeFun [typeInt, typeFloat, typeFloat]
335-
, typeFun [typeFloat, typeInt, typeFloat]
336-
, typeFun [typeString, typeString, typeString]
337-
, typeFun [typePath, typePath, typePath]
338-
, typeFun [typeString, typeString, typePath]
339-
]
340-
)
341-
]
342-
NMinus -> arithmetic
343-
NMult -> arithmetic
344-
NDiv -> arithmetic
325+
binops u1 op =
326+
if
327+
-- NApp in fact is handled separately
328+
-- Equality tells nothing about the types, because any two types are allowed.
329+
| op `elem` [ NApp , NEq , NNEq ] -> mempty
330+
| op `elem` [ NGt , NGte , NLt , NLte ] -> inequality
331+
| op `elem` [ NAnd , NOr , NImpl ] -> gate
332+
| op == NConcat -> concatenation
333+
| op `elem` [ NMinus, NMult, NDiv ] -> arithmetic
334+
| op == NUpdate -> rUnion
335+
| op == NPlus -> addition
336+
| otherwise -> fail "GHC so far can not infer that this pattern match is full, so make it happy."
337+
345338
where
339+
340+
gate = eqCnst [typeBool, typeBool, typeBool]
341+
concatenation = eqCnst [typeList, typeList, typeList]
342+
343+
eqCnst l = [EqConst u1 (typeFun l)]
344+
346345
inequality =
347-
[ EqConst
348-
u1
349-
(TMany
350-
[ typeFun [typeInt, typeInt, typeBool]
351-
, typeFun [typeFloat, typeFloat, typeBool]
352-
, typeFun [typeInt, typeFloat, typeBool]
353-
, typeFun [typeFloat, typeInt, typeBool]
354-
]
355-
)
356-
]
346+
eqCnstMtx
347+
[ [typeInt , typeInt , typeBool]
348+
, [typeFloat, typeFloat, typeBool]
349+
, [typeInt , typeFloat, typeBool]
350+
, [typeFloat, typeInt , typeBool]
351+
]
357352

358353
arithmetic =
359-
[ EqConst
360-
u1
361-
(TMany
362-
[ typeFun [typeInt, typeInt, typeInt]
363-
, typeFun [typeFloat, typeFloat, typeFloat]
364-
, typeFun [typeInt, typeFloat, typeFloat]
365-
, typeFun [typeFloat, typeInt, typeFloat]
366-
]
367-
)
368-
]
354+
eqCnstMtx
355+
[ [typeInt , typeInt , typeInt ]
356+
, [typeFloat, typeFloat, typeFloat]
357+
, [typeInt , typeFloat, typeFloat]
358+
, [typeFloat, typeInt , typeFloat]
359+
]
360+
361+
rUnion =
362+
eqCnstMtx
363+
[ [typeSet , typeSet , typeSet]
364+
, [typeSet , typeNull, typeSet]
365+
, [typeNull, typeSet , typeSet]
366+
]
367+
368+
addition =
369+
eqCnstMtx
370+
[ [typeInt , typeInt , typeInt ]
371+
, [typeFloat , typeFloat , typeFloat ]
372+
, [typeInt , typeFloat , typeFloat ]
373+
, [typeFloat , typeInt , typeFloat ]
374+
, [typeString, typeString, typeString]
375+
, [typePath , typePath , typePath ]
376+
, [typeString, typeString, typePath ]
377+
]
378+
379+
eqCnstMtx mtx = [EqConst u1 (TMany (typeFun <$> mtx))]
369380

370381
liftInfer :: Monad m => m a -> InferT s m a
371382
liftInfer = InferT . lift . lift . lift
@@ -377,10 +388,12 @@ instance MonadRef m => MonadRef (InferT s m) where
377388
writeRef x y = liftInfer $ writeRef x y
378389

379390
instance MonadAtomicRef m => MonadAtomicRef (InferT s m) where
380-
atomicModifyRef x f = liftInfer $ do
381-
res <- snd . f <$> readRef x
382-
_ <- modifyRef x (fst . f)
383-
pure res
391+
atomicModifyRef x f =
392+
liftInfer $
393+
do
394+
res <- snd . f <$> readRef x
395+
_ <- modifyRef x (fst . f)
396+
pure res
384397

385398
-- newtype JThunkT s m = JThunk (NThunkF (InferT s m) (Judgment s))
386399

@@ -669,7 +682,7 @@ instance MonadTrans Solver where
669682
lift = Solver . lift . lift
670683

671684
instance Monad m => MonadError TypeError (Solver m) where
672-
throwError err = Solver $ lift (modify (err :)) *> mzero
685+
throwError err = Solver $ lift (modify (err :)) *> empty
673686
catchError _ _ = error "This is never used"
674687

675688
runSolver :: Monad m => Solver m a -> m (Either [TypeError] [a])

0 commit comments

Comments
 (0)