1+ {-# LANGUAGE MultiWayIf #-}
12{-# LANGUAGE CPP #-}
23{-# LANGUAGE AllowAmbiguousTypes #-}
34{-# LANGUAGE ConstraintKinds #-}
@@ -27,18 +28,36 @@ module Nix.Type.Infer
2728 )
2829where
2930
30- import Control.Applicative
31- import Control.Arrow
32- import Control.Monad.Catch
33- import Control.Monad.Except
31+ import Control.Applicative ( Alternative
32+ , empty
33+ )
34+ import Data.Bifunctor ( Bifunctor (first ) )
35+ import Control.Monad.Catch ( Exception (fromException , toException )
36+ , MonadThrow (.. )
37+ , MonadCatch (.. )
38+ )
39+ import Control.Monad.Except ( ExceptT
40+ , MonadError (.. ), runExceptT
41+ )
3442#if !MIN_VERSION_base(4,13,0)
43+ import Prelude hiding ( fail )
3544import Control.Monad.Fail
3645#endif
37- import Control.Monad.Logic
38- import Control.Monad.Reader
46+ import Control.Monad.Logic hiding ( fail )
47+ import Control.Monad.Reader ( MonadReader (local )
48+ , ReaderT (.. )
49+ , MonadFix
50+ )
3951import Control.Monad.Ref
40- import Control.Monad.ST
41- import Control.Monad.State.Strict
52+ import Control.Monad.ST ( ST
53+ , runST
54+ )
55+ import Control.Monad.State.Strict ( modify
56+ , evalState
57+ , evalStateT
58+ , MonadState (put , get )
59+ , StateT (runStateT )
60+ )
4261import Data.Fix ( foldFix )
4362import Data.Foldable ( foldl'
4463 , foldrM
@@ -65,7 +84,7 @@ import Nix.Fresh
6584import Nix.String
6685import Nix.Scope
6786import qualified Nix.Type.Assumption as As
68- import Nix.Type.Env
87+ import Nix.Type.Env hiding ( empty )
6988import qualified Nix.Type.Env as Env
7089import Nix.Type.Type
7190import Nix.Utils
@@ -79,8 +98,10 @@ import Nix.Var
7998newtype InferT s m a =
8099 InferT
81100 { getInfer ::
82- ReaderT (Set. Set TVar , Scopes (InferT s m ) (Judgment s ))
83- (StateT InferState (ExceptT InferError m )) a
101+ ReaderT
102+ (Set. Set TVar , Scopes (InferT s m ) (Judgment s ))
103+ (StateT InferState (ExceptT InferError m ))
104+ a
84105 }
85106 deriving
86107 ( Functor
@@ -236,25 +257,28 @@ inferType env ex = do
236257 Set. fromList (As. keys as) `Set.difference` Set. fromList (Env. keys env)
237258 unless (Set. null unbounds) $ typeError $ UnboundVariables
238259 (nub (Set. toList unbounds))
239- let cs' =
240- [ ExpInstConst t s
241- | (x, ss) <- Env. toList env
242- , s <- ss
243- , t <- As. lookup x as
244- ]
260+ let
261+ cs' =
262+ [ ExpInstConst t s
263+ | (x, ss) <- Env. toList env
264+ , s <- ss
265+ , t <- As. lookup x as
266+ ]
245267 inferState <- get
246- let eres = (`evalState` inferState) $ runSolver $ do
268+ let
269+ eres = (`evalState` inferState) $ runSolver $
270+ do
247271 subst <- solve (cs <> cs')
248272 pure (subst, subst `apply` t)
249- case eres of
250- Left errs -> throwError $ TypeInferenceErrors errs
251- Right xs -> pure xs
273+ either
274+ (throwError . TypeInferenceErrors )
275+ pure
276+ eres
252277
253278-- | Solve for the toplevel type of an expression in a given environment
254279inferExpr :: Env -> NExpr -> Either InferError [Scheme ]
255- inferExpr env ex = case runInfer (inferType env ex) of
256- Left err -> Left err
257- Right xs -> Right $ fmap (\ (subst, ty) -> closeOver (subst `apply` ty)) xs
280+ inferExpr env ex =
281+ (fmap . fmap ) (\ (subst, ty) -> closeOver (subst `apply` ty)) $ runInfer $ inferType env ex
258282
259283-- | Canonicalize and return the polymorphic toplevel type.
260284closeOver :: Type -> Scheme
@@ -267,105 +291,92 @@ letters :: [String]
267291letters = [1 .. ] >>= flip replicateM [' a' .. ' z' ]
268292
269293freshTVar :: MonadState InferState m => m TVar
270- freshTVar = do
271- s <- get
272- put s { count = count s + 1 }
273- pure $ TV (letters !! count s)
294+ freshTVar =
295+ do
296+ s <- get
297+ put s { count = count s + 1 }
298+ pure $ TV (letters !! count s)
274299
275300fresh :: MonadState InferState m => m Type
276301fresh = TVar <$> freshTVar
277302
278303instantiate :: MonadState InferState m => Scheme -> m Type
279- instantiate (Forall as t) = do
280- as' <- traverse (const fresh) as
281- let s = Subst $ Map. fromList $ zip as as'
282- pure $ apply s t
304+ instantiate (Forall as t) =
305+ do
306+ as' <- traverse (const fresh) as
307+ let s = Subst $ Map. fromList $ zip as as'
308+ pure $ apply s t
283309
284310generalize :: Set. Set TVar -> Type -> Scheme
285311generalize free t = Forall as t
286- where as = Set. toList $ ftv t `Set.difference` free
312+ where
313+ as = Set. toList $ ftv t `Set.difference` free
287314
288315unops :: Type -> NUnaryOp -> [Constraint ]
289- unops u1 = \ case
290- NNot -> [ EqConst u1 (typeFun [typeBool, typeBool])]
291- NNeg ->
292- [ EqConst
293- u1
294- ( TMany [typeFun [typeInt, typeInt], typeFun [typeFloat, typeFloat]] )
295- ]
316+ unops u1 op =
317+ [ EqConst u1
318+ ( case op of
319+ NNot -> typeFun [typeBool , typeBool ]
320+ NNeg -> TMany [typeFun [typeInt, typeInt], typeFun [typeFloat, typeFloat]]
321+ )
322+ ]
296323
297324binops :: Type -> NBinaryOp -> [Constraint ]
298- binops u1 = \ case
299- NApp -> mempty -- this is handled separately
300-
301- -- Equality tells you nothing about the types, because any two types are
302- -- allowed.
303- NEq -> mempty
304- NNEq -> mempty
305-
306- NGt -> inequality
307- NGte -> inequality
308- NLt -> inequality
309- NLte -> inequality
310-
311- NAnd -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
312- NOr -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
313- NImpl -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
314-
315- NConcat -> [EqConst u1 (typeFun [typeList, typeList, typeList])]
316-
317- NUpdate ->
318- [ EqConst
319- u1
320- (TMany
321- [ typeFun [typeSet, typeSet, typeSet]
322- , typeFun [typeSet, typeNull, typeSet]
323- , typeFun [typeNull, typeSet, typeSet]
324- ]
325- )
326- ]
327-
328- NPlus ->
329- [ EqConst
330- u1
331- (TMany
332- [ typeFun [typeInt, typeInt, typeInt]
333- , typeFun [typeFloat, typeFloat, typeFloat]
334- , typeFun [typeInt, typeFloat, typeFloat]
335- , typeFun [typeFloat, typeInt, typeFloat]
336- , typeFun [typeString, typeString, typeString]
337- , typeFun [typePath, typePath, typePath]
338- , typeFun [typeString, typeString, typePath]
339- ]
340- )
341- ]
342- NMinus -> arithmetic
343- NMult -> arithmetic
344- NDiv -> arithmetic
325+ binops u1 op =
326+ if
327+ -- NApp in fact is handled separately
328+ -- Equality tells nothing about the types, because any two types are allowed.
329+ | op `elem` [ NApp , NEq , NNEq ] -> mempty
330+ | op `elem` [ NGt , NGte , NLt , NLte ] -> inequality
331+ | op `elem` [ NAnd , NOr , NImpl ] -> gate
332+ | op == NConcat -> concatenation
333+ | op `elem` [ NMinus , NMult , NDiv ] -> arithmetic
334+ | op == NUpdate -> rUnion
335+ | op == NPlus -> addition
336+ | otherwise -> fail " GHC so far can not infer that this pattern match is full, so make it happy."
337+
345338 where
339+
340+ gate = eqCnst [typeBool, typeBool, typeBool]
341+ concatenation = eqCnst [typeList, typeList, typeList]
342+
343+ eqCnst l = [EqConst u1 (typeFun l)]
344+
346345 inequality =
347- [ EqConst
348- u1
349- (TMany
350- [ typeFun [typeInt, typeInt, typeBool]
351- , typeFun [typeFloat, typeFloat, typeBool]
352- , typeFun [typeInt, typeFloat, typeBool]
353- , typeFun [typeFloat, typeInt, typeBool]
354- ]
355- )
356- ]
346+ eqCnstMtx
347+ [ [typeInt , typeInt , typeBool]
348+ , [typeFloat, typeFloat, typeBool]
349+ , [typeInt , typeFloat, typeBool]
350+ , [typeFloat, typeInt , typeBool]
351+ ]
357352
358353 arithmetic =
359- [ EqConst
360- u1
361- (TMany
362- [ typeFun [typeInt, typeInt, typeInt]
363- , typeFun [typeFloat, typeFloat, typeFloat]
364- , typeFun [typeInt, typeFloat, typeFloat]
365- , typeFun [typeFloat, typeInt, typeFloat]
366- ]
367- )
368- ]
354+ eqCnstMtx
355+ [ [typeInt , typeInt , typeInt ]
356+ , [typeFloat, typeFloat, typeFloat]
357+ , [typeInt , typeFloat, typeFloat]
358+ , [typeFloat, typeInt , typeFloat]
359+ ]
360+
361+ rUnion =
362+ eqCnstMtx
363+ [ [typeSet , typeSet , typeSet]
364+ , [typeSet , typeNull, typeSet]
365+ , [typeNull, typeSet , typeSet]
366+ ]
367+
368+ addition =
369+ eqCnstMtx
370+ [ [typeInt , typeInt , typeInt ]
371+ , [typeFloat , typeFloat , typeFloat ]
372+ , [typeInt , typeFloat , typeFloat ]
373+ , [typeFloat , typeInt , typeFloat ]
374+ , [typeString, typeString, typeString]
375+ , [typePath , typePath , typePath ]
376+ , [typeString, typeString, typePath ]
377+ ]
378+
379+ eqCnstMtx mtx = [EqConst u1 (TMany (typeFun <$> mtx))]
369380
370381liftInfer :: Monad m => m a -> InferT s m a
371382liftInfer = InferT . lift . lift . lift
@@ -377,10 +388,12 @@ instance MonadRef m => MonadRef (InferT s m) where
377388 writeRef x y = liftInfer $ writeRef x y
378389
379390instance MonadAtomicRef m => MonadAtomicRef (InferT s m ) where
380- atomicModifyRef x f = liftInfer $ do
381- res <- snd . f <$> readRef x
382- _ <- modifyRef x (fst . f)
383- pure res
391+ atomicModifyRef x f =
392+ liftInfer $
393+ do
394+ res <- snd . f <$> readRef x
395+ _ <- modifyRef x (fst . f)
396+ pure res
384397
385398-- newtype JThunkT s m = JThunk (NThunkF (InferT s m) (Judgment s))
386399
@@ -669,7 +682,7 @@ instance MonadTrans Solver where
669682 lift = Solver . lift . lift
670683
671684instance Monad m => MonadError TypeError (Solver m ) where
672- throwError err = Solver $ lift (modify (err : )) *> mzero
685+ throwError err = Solver $ lift (modify (err : )) *> empty
673686 catchError _ _ = error " This is never used"
674687
675688runSolver :: Monad m => Solver m a -> m (Either [TypeError ] [a ])
0 commit comments