Skip to content

Commit 884e16a

Browse files
committed
Type.Infer: refactor
1 parent 84d9625 commit 884e16a

File tree

1 file changed

+74
-55
lines changed

1 file changed

+74
-55
lines changed

src/Nix/Type/Infer.hs

Lines changed: 74 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,28 @@ import Nix.Utils
7272
import Nix.Value.Monad
7373
import Nix.Var
7474

75-
---------------------------------------------------------------------------------
75+
7676
-- * Classes
77-
---------------------------------------------------------------------------------
7877

7978
-- | Inference monad
80-
newtype InferT s m a = InferT
79+
newtype InferT s m a =
80+
InferT
8181
{ getInfer ::
8282
ReaderT (Set.Set TVar, Scopes (InferT s m) (Judgment s))
8383
(StateT InferState (ExceptT InferError m)) a
8484
}
8585
deriving
86-
( Functor
87-
, Applicative
88-
, Alternative
89-
, Monad
90-
, MonadPlus
91-
, MonadFix
92-
, MonadReader (Set.Set TVar, Scopes (InferT s m) (Judgment s))
93-
, MonadFail
94-
, MonadState InferState
95-
, MonadError InferError
96-
)
86+
( Functor
87+
, Applicative
88+
, Alternative
89+
, Monad
90+
, MonadPlus
91+
, MonadFix
92+
, MonadReader (Set.Set TVar, Scopes (InferT s m) (Judgment s))
93+
, MonadFail
94+
, MonadState InferState
95+
, MonadError InferError
96+
)
9797

9898
instance MonadTrans (InferT s) where
9999
lift = InferT . lift . lift . lift
@@ -109,10 +109,10 @@ initInfer :: InferState
109109
initInfer = InferState { count = 0 }
110110

111111
data Constraint
112-
= EqConst Type Type
113-
| ExpInstConst Type Scheme
114-
| ImpInstConst Type (Set.Set TVar) Type
115-
deriving (Show, Eq, Ord)
112+
= EqConst Type Type
113+
| ExpInstConst Type Scheme
114+
| ImpInstConst Type (Set.Set TVar) Type
115+
deriving (Show, Eq, Ord)
116116

117117
newtype Subst = Subst (Map TVar Type)
118118
deriving (Eq, Ord, Show, Semigroup, Monoid)
@@ -179,10 +179,9 @@ class ActiveTypeVars a where
179179
atv :: a -> Set.Set TVar
180180

181181
instance ActiveTypeVars Constraint where
182-
atv (EqConst t1 t2) = ftv t1 `Set.union` ftv t2
183-
atv (ImpInstConst t1 ms t2) =
184-
ftv t1 `Set.union` (ftv ms `Set.intersection` ftv t2)
185-
atv (ExpInstConst t s) = ftv t `Set.union` ftv s
182+
atv (EqConst t1 t2 ) = ftv t1 `Set.union` ftv t2
183+
atv (ImpInstConst t1 ms t2) = ftv t1 `Set.union` (ftv ms `Set.intersection` ftv t2)
184+
atv (ExpInstConst t s ) = ftv t `Set.union` ftv s
186185

187186
instance ActiveTypeVars a => ActiveTypeVars [a] where
188187
atv = foldr (Set.union . atv) mempty
@@ -213,9 +212,8 @@ instance Monoid InferError where
213212
mempty = TypeInferenceAborted
214213
mappend = (<>)
215214

216-
---------------------------------------------------------------------------------
215+
217216
-- * Inference
218-
---------------------------------------------------------------------------------
219217

220218
-- | Run the inference monad
221219
runInfer' :: MonadInfer m => InferT s m a -> m (Either InferError a)
@@ -535,57 +533,75 @@ instance MonadInfer m => MonadEval (Judgment s) (InferT s m) where
535533
(tv :~> t)
536534

537535
evalAbs (ParamSet ps variadic _mname) k = do
538-
js <- fmap concat $ forM ps $ \(name, _) -> do
539-
tv <- fresh
540-
pure [(name, tv)]
541-
542-
let (env, tys) =
543-
(\f -> foldl' f (As.empty, mempty) js) $ \(as1, t1) (k, t) ->
544-
(as1 `As.merge` As.singleton k t, M.insert k t t1)
545-
arg = pure $ Judgment env mempty (TSet True tys)
546-
call = k arg $ \args b -> (args, ) <$> b
547-
names = fmap fst js
536+
js <-
537+
concat <$>
538+
traverse
539+
(\(name, _) ->
540+
do
541+
tv <- fresh
542+
pure [(name, tv)]
543+
)
544+
ps
545+
546+
let
547+
(env, tys) =
548+
(\f -> foldl' f (As.empty, mempty) js) $ \(as1, t1) (k, t) ->
549+
(as1 `As.merge` As.singleton k t, M.insert k t t1)
550+
arg = pure $ Judgment env mempty (TSet True tys)
551+
call = k arg $ \args b -> (args, ) <$> b
552+
names = fmap fst js
548553

549554
(args, Judgment as cs t) <- foldr (\(_, TVar a) -> extendMSet a) call js
550555

551556
ty <- TSet variadic <$> traverse (inferredType <$>) args
552557

553-
pure $ Judgment
554-
(foldl' As.remove as names)
555-
(cs <> [ EqConst t' (tys M.! x) | x <- names, t' <- As.lookup x as ])
556-
(ty :~> t)
558+
pure $
559+
Judgment
560+
(foldl' As.remove as names)
561+
(cs <> [ EqConst t' (tys M.! x) | x <- names, t' <- As.lookup x as ])
562+
(ty :~> t)
557563

558564
evalError = throwError . EvaluationError
559565

560-
data Judgment s = Judgment
566+
data Judgment s =
567+
Judgment
561568
{ assumptions :: As.Assumption
562569
, typeConstraints :: [Constraint]
563570
, inferredType :: Type
564571
}
565572
deriving Show
566573

567-
instance Monad m => FromValue NixString (InferT s m) (Judgment s) where
574+
instance
575+
Monad m
576+
=> FromValue NixString (InferT s m) (Judgment s)
577+
where
568578
fromValueMay _ = pure mempty
569579
fromValue _ = error "Unused"
570580

571-
instance MonadInfer m
572-
=> FromValue (AttrSet (Judgment s), AttrSet SourcePos)
573-
(InferT s m) (Judgment s) where
574-
fromValueMay (Judgment _ _ (TSet _ xs)) = do
575-
let sing _ = Judgment As.empty mempty
576-
pure $ pure (M.mapWithKey sing xs, mempty)
581+
instance
582+
MonadInfer m
583+
=> FromValue ( AttrSet (Judgment s)
584+
, AttrSet SourcePos
585+
) (InferT s m) (Judgment s)
586+
where
587+
fromValueMay (Judgment _ _ (TSet _ xs)) =
588+
do
589+
let sing _ = Judgment As.empty mempty
590+
pure $ pure (M.mapWithKey sing xs, mempty)
577591
fromValueMay _ = pure mempty
578-
fromValue = fromValueMay >=>
579-
pure . fromMaybe
592+
fromValue =
593+
pure .
594+
fromMaybe
580595
(mempty, mempty)
596+
<=< fromValueMay
581597

582598
instance MonadInfer m
583599
=> ToValue (AttrSet (Judgment s), AttrSet SourcePos)
584600
(InferT s m) (Judgment s) where
585601
toValue (xs, _) =
586602
Judgment
587603
<$> foldrM go As.empty xs
588-
<*> (concat <$> traverse ((pure . typeConstraints) <=< demand ) xs)
604+
<*> (concat <$> traverse ((pure . typeConstraints) <=< demand) xs)
589605
<*> (TSet True <$> traverse ((pure . inferredType) <=< demand) xs)
590606
where
591607
go x rest =
@@ -636,13 +652,14 @@ normalizeScheme (Forall _ body) = Forall (fmap snd ord) (normtype body)
636652
normtype (TSet b a) = TSet b (M.map normtype a)
637653
normtype (TList a ) = TList (fmap normtype a)
638654
normtype (TMany ts) = TMany (fmap normtype ts)
639-
normtype (TVar a ) = case Prelude.lookup a ord of
640-
Just x -> TVar x
641-
Nothing -> error "type variable not in signature"
655+
normtype (TVar a ) =
656+
maybe
657+
(error "type variable not in signature")
658+
TVar
659+
(Prelude.lookup a ord)
660+
642661

643-
---------------------------------------------------------------------------------
644662
-- * Constraint Solver
645-
---------------------------------------------------------------------------------
646663

647664
newtype Solver m a = Solver (LogicT (StateT [TypeError] m) a)
648665
deriving (Functor, Applicative, Alternative, Monad, MonadPlus,
@@ -743,7 +760,9 @@ solve cs = solve' (nextSolvable cs)
743760
s' <- lift $ instantiate s
744761
solve (EqConst t s' : cs)
745762

746-
instance Monad m => Scoped (Judgment s) (InferT s m) where
763+
instance
764+
Monad m
765+
=> Scoped (Judgment s) (InferT s m) where
747766
currentScopes = currentScopesReader
748767
clearScopes = clearScopesReader @(InferT s m) @(Judgment s)
749768
pushScopes = pushScopesReader

0 commit comments

Comments
 (0)