Skip to content

Commit 3fbcc02

Browse files
authored
Merge pull request #1322 from google-research/type-checking-refactor
Update type checker in anticipation of decls-in-types.
2 parents 2580bc9 + 39c45b8 commit 3fbcc02

File tree

11 files changed

+790
-821
lines changed

11 files changed

+790
-821
lines changed

src/lib/CheckType.hs

Lines changed: 672 additions & 799 deletions
Large diffs are not rendered by default.

src/lib/Core.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ newtype EnvReaderT (m::MonadKind) (n::S) (a:: *) =
8282
, MonadWriter w, Fallible, Searcher, Alternative)
8383

8484
type EnvReaderM = EnvReaderT Identity
85+
type FallibleEnvReaderM = EnvReaderT FallibleM
8586

8687
runEnvReaderM :: Distinct n => Env n -> EnvReaderM n a -> a
8788
runEnvReaderM bindings m = runIdentity $ runEnvReaderT bindings m

src/lib/Generalize.hs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ module Generalize (generalizeArgs, generalizeIxDict) where
88

99
import Control.Monad
1010

11-
import CheckType (isData)
1211
import Core
1312
import Err
1413
import Types.Core

src/lib/Imp.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1535,7 +1535,7 @@ impInstrTypes instr = case instr of
15351535
where hostPtrTy ty = PtrType (CPU, ty)
15361536

15371537
instance CheckableE SimpIR ImpFunction where
1538-
checkE _ = return () -- TODO
1538+
checkE = renameM -- TODO
15391539

15401540
-- TODO: Don't use Core Envs for Imp!
15411541
instance BindsEnv ImpDecl where

src/lib/Inference.hs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
module Inference
1212
( inferTopUDecl, checkTopUType, inferTopUExpr
1313
, trySynthTerm, generalizeDict, asTopBlock
14-
, synthTopE, UDeclInferenceResult (..)) where
14+
, synthTopE, UDeclInferenceResult (..), asFFIFunType) where
1515

1616
import Prelude hiding ((.), id)
1717
import Control.Category
@@ -45,6 +45,7 @@ import SourceInfo
4545
import Subst
4646
import QueryType
4747
import Types.Core
48+
import Types.Imp
4849
import Types.Primitives
4950
import Types.Source
5051
import Util hiding (group)
@@ -3184,6 +3185,43 @@ withFabricatedEmitsInf cont = fromWrapWithEmitsInf
31843185
newtype WrapWithEmitsInf n r =
31853186
WrapWithEmitsInf { fromWrapWithEmitsInf :: EmitsInf n => r }
31863187

3188+
-- === IFunType ===
3189+
3190+
asFFIFunType :: EnvReader m => CType n -> m n (Maybe (IFunType, CorePiType n))
3191+
asFFIFunType ty = return do
3192+
Pi piTy <- return ty
3193+
impTy <- checkFFIFunTypeM piTy
3194+
return (impTy, piTy)
3195+
3196+
checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType
3197+
checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) effTy) = do
3198+
argTy <- checkScalar $ binderType b
3199+
case bs of
3200+
Empty -> do
3201+
resultTys <- checkScalarOrPairType (etTy effTy)
3202+
let cc = case length resultTys of
3203+
0 -> error "Not implemented"
3204+
1 -> FFICC
3205+
_ -> FFIMultiResultCC
3206+
return $ IFunType cc [argTy] resultTys
3207+
Nest b' rest -> do
3208+
let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy
3209+
IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest
3210+
return $ IFunType cc (argTy:argTys) resultTys
3211+
checkFFIFunTypeM _ = error "expected at least one argument"
3212+
3213+
checkScalar :: (IRRep r, Fallible m) => Type r n -> m BaseType
3214+
checkScalar (BaseTy ty) = return ty
3215+
checkScalar ty = throw TypeErr $ pprint ty
3216+
3217+
checkScalarOrPairType :: (IRRep r, Fallible m) => Type r n -> m [BaseType]
3218+
checkScalarOrPairType (PairTy a b) = do
3219+
tys1 <- checkScalarOrPairType a
3220+
tys2 <- checkScalarOrPairType b
3221+
return $ tys1 ++ tys2
3222+
checkScalarOrPairType (BaseTy ty) = return [ty]
3223+
checkScalarOrPairType ty = throw TypeErr $ pprint ty
3224+
31873225
-- === instances ===
31883226

31893227
instance PrettyE e => Pretty (UDeclInferenceResult e l) where
@@ -3197,9 +3235,11 @@ instance SinkableE e => SinkableE (UDeclInferenceResult e) where
31973235

31983236
instance (RenameE e, CheckableE CoreIR e) => CheckableE CoreIR (UDeclInferenceResult e) where
31993237
checkE = \case
3200-
UDeclResultDone _ -> return ()
3201-
UDeclResultBindName _ block _ -> checkE block
3202-
UDeclResultBindPattern _ block _ -> checkE block
3238+
UDeclResultDone e -> UDeclResultDone <$> checkE e
3239+
UDeclResultBindName ann block ab ->
3240+
UDeclResultBindName ann <$> checkE block <*> renameM ab -- TODO: check result
3241+
UDeclResultBindPattern hint block recon ->
3242+
UDeclResultBindPattern hint <$> checkE block <*> renameM recon -- TODO: check recon
32033243

32043244
instance HasType CoreIR InfEmission where
32053245
getType = \case

src/lib/Linearize.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ linearize f x = runPrimalMInit $ linearizeLambdaApp f x
291291
linearizeTopLam :: STopLam n -> [Active] -> DoubleBuilder SimpIR n (STopLam n, STopLam n)
292292
linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do
293293
(primalFun, tangentFun) <- runPrimalMInit $ refreshBinders bs \bs' frag -> extendSubst frag do
294-
let allPrimals = nestToAtomVars bs'
294+
let allPrimals = bindersVars bs'
295295
activeVs <- catMaybes <$> forM (zip actives allPrimals) \(active, v) -> case active of
296296
True -> return $ Just v
297297
False -> return $ Nothing

src/lib/QueryType.hs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case
213213
MethodBinding className i -> do
214214
ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className
215215
refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' absPiTy -> do
216-
let params = Var <$> nestToAtomVars paramBs'
216+
let params = Var <$> bindersVars paramBs'
217217
dictTy <- DictTy <$> dictType (sink className) params
218218
withFreshBinder noHint dictTy \dictB -> do
219219
scDicts <- getSuperclassDicts (Var $ binderVar dictB)
@@ -384,3 +384,47 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where
384384
t:ts -> withFreshBinder noHint (BaseTy t) \b -> do
385385
PiType bs effTy <- go ts
386386
return $ PiType (Nest b bs) effTy
387+
388+
-- === Data constraints ===
389+
390+
isData :: EnvReader m => Type CoreIR n -> m n Bool
391+
isData ty = do
392+
result <- liftEnvReaderT $ withSubstReaderT $ checkDataLike ty
393+
case runFallibleM result of
394+
Success () -> return True
395+
Failure _ -> return False
396+
397+
checkDataLike :: Type CoreIR i -> SubstReaderT Name FallibleEnvReaderM i o ()
398+
checkDataLike ty = case ty of
399+
TyVar _ -> notData
400+
TabPi (TabPiType _ b eltTy) -> do
401+
renameBinders b \_ ->
402+
checkDataLike eltTy
403+
DepPairTy (DepPairType _ b@(_:>l) r) -> do
404+
recur l
405+
renameBinders b \_ -> checkDataLike r
406+
NewtypeTyCon nt -> do
407+
(_, ty') <- unwrapNewtypeType =<< renameM nt
408+
dropSubst $ recur ty'
409+
TC con -> case con of
410+
BaseType _ -> return ()
411+
ProdType as -> mapM_ recur as
412+
SumType cs -> mapM_ recur cs
413+
RefType _ _ -> return ()
414+
HeapType -> return ()
415+
_ -> notData
416+
_ -> notData
417+
where
418+
recur = checkDataLike
419+
notData = throw TypeErr $ pprint ty
420+
421+
checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m ()
422+
checkExtends allowed (EffectRow effs effTail) = do
423+
let (EffectRow allowedEffs allowedEffTail) = allowed
424+
case effTail of
425+
EffectRowTail _ -> assertEq allowedEffTail effTail ""
426+
NoTail -> return ()
427+
forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $
428+
throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++
429+
"\nAllowed: " ++ pprint allowed
430+

src/lib/Simplify.hs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Data.Text.Prettyprint.Doc (Pretty (..), hardline)
1919

2020
import Builder
2121
import CheapReduction
22-
import CheckType (CheckableE (..), isData, checkBlock)
22+
import CheckType
2323
import Core
2424
import Err
2525
import Generalize
@@ -274,9 +274,12 @@ instance SinkableE SimplifiedBlock
274274
instance RenameE SimplifiedBlock
275275
instance HoistableE SimplifiedBlock
276276
instance CheckableE SimpIR SimplifiedBlock where
277-
checkE (SimplifiedBlock block _) =
278-
-- TODO: CheckableE instance for the recon too
279-
void $ checkBlock block
277+
checkE (SimplifiedBlock block recon) = do
278+
block' <- renameM block
279+
effTy <- blockEffTy block' -- TODO: store this in the simplified block instead
280+
block'' <- dropSubst $ checkBlock effTy block'
281+
recon' <- renameM recon -- TODO: CheckableE instance for the recon too
282+
return $ SimplifiedBlock block'' recon'
280283

281284
instance Pretty (SimplifiedBlock n) where
282285
pretty (SimplifiedBlock block recon) =
@@ -286,9 +289,9 @@ instance SinkableE SimplifiedTopLam where
286289
sinkingProofE = todoSinkableProof
287290

288291
instance CheckableE SimpIR SimplifiedTopLam where
289-
checkE (SimplifiedTopLam lam _) = do
292+
checkE (SimplifiedTopLam lam recon) =
290293
-- TODO: CheckableE instance for the recon too
291-
checkE lam
294+
SimplifiedTopLam <$> checkE lam <*> renameM recon
292295

293296
instance Pretty (SimplifiedTopLam n) where
294297
pretty (SimplifiedTopLam lam recon) = pretty lam <> hardline <> pretty recon

src/lib/TopLevel.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ import qualified LLVM.AST
4343

4444
import AbstractSyntax
4545
import Builder
46-
import CheckType ( CheckableE (..), asFFIFunType, checkHasType)
46+
import CheckType ( CheckableE (..), checkTypeIs)
4747
#ifdef DEX_DEBUG
48-
import CheckType (checkTypesM)
48+
import CheckType (checkTypes)
4949
#endif
5050
import Core
5151
import ConcreteSyntax
@@ -316,7 +316,7 @@ evalSourceBlock' mname block = case sbContents block of
316316
_ -> evalUExpr expr
317317
fType <- getType <$> toAtomVar fname'
318318
(nimplicit, nexplicit, linFunTy) <- liftExceptEnvReaderM $ getLinearizationType zeros fType
319-
impl `checkHasType` linFunTy >>= \case
319+
liftEnvReaderT (impl `checkTypeIs` linFunTy) >>= \case
320320
Failure _ -> do
321321
let implTy = getType impl
322322
throw TypeErr $ unlines
@@ -744,7 +744,7 @@ checkPass name cont = do
744744
return result
745745
#ifdef DEX_DEBUG
746746
logTop $ MiscLog $ "Running checks"
747-
checkTypesM result
747+
checkTypes result
748748
logTop $ MiscLog $ "Checks passed"
749749
#else
750750
logTop $ MiscLog $ "Checks skipped (not a debug build)"

src/lib/Types/Core.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -943,12 +943,12 @@ binderType (_:>ty) = ty
943943
binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l
944944
binderVar (b:>ty) = AtomVar (binderName b) (sink ty)
945945

946-
nestToAtomVars :: (Distinct l, Ext n l, IRRep r)
947-
=> Nest (Binder r) n l -> [AtomVar r l]
948-
nestToAtomVars = \case
946+
bindersVars :: (Distinct l, Ext n l, IRRep r)
947+
=> Nest (Binder r) n l -> [AtomVar r l]
948+
bindersVars = \case
949949
Empty -> []
950950
Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $
951-
sink (binderVar b) : nestToAtomVars bs
951+
sink (binderVar b) : bindersVars bs
952952

953953
-- === ToBinding ===
954954

0 commit comments

Comments
 (0)