Skip to content

Commit 360cd31

Browse files
authored
Merge pull request #1318 from google-research/delivering-even-more-decls
More prep for decls-in-binders (again)
2 parents cfab914 + d13ee9c commit 360cd31

File tree

10 files changed

+168
-140
lines changed

10 files changed

+168
-140
lines changed

src/lib/Builder.hs

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -626,31 +626,14 @@ buildAbs hint binding cont = do
626626
return $ Abs b body
627627
{-# INLINE buildAbs #-}
628628

629-
varsAsBinderNest :: (EnvReader m, IRRep r) => [AtomVar r n] -> m n (EmptyAbs (Nest (Binder r)) n)
630-
varsAsBinderNest [] = return $ EmptyAbs Empty
631-
varsAsBinderNest (v:vs) = do
632-
rest <- varsAsBinderNest vs
633-
ty <- return $ getType v
634-
let AtomVar v' _ = v
635-
Abs b (Abs bs UnitE) <- return $ abstractFreeVar v' rest
636-
return $ EmptyAbs (Nest (b:>ty) bs)
637-
638629
typesFromNonDepBinderNest
639630
:: (EnvReader m, Fallible1 m, IRRep r)
640631
=> Nest (Binder r) n l -> m n [Type r n]
641632
typesFromNonDepBinderNest Empty = return []
642-
typesFromNonDepBinderNest (Nest (b:>ty) rest) = do
643-
Abs rest' UnitE <- return $ ignoreHoistFailure $ hoist b (Abs rest UnitE)
633+
typesFromNonDepBinderNest (Nest b rest) = do
634+
Abs rest' UnitE <- return $ assumeConst $ Abs (UnaryNest b) $ Abs rest UnitE
644635
tys <- typesFromNonDepBinderNest rest'
645-
return $ ty : tys
646-
647-
singletonBinderNest
648-
:: (EnvReader m, IRRep r)
649-
=> NameHint -> ann n
650-
-> m n (EmptyAbs (Nest (BinderP (AtomNameC r) ann)) n)
651-
singletonBinderNest hint ann = do
652-
Abs b _ <- return $ newName hint
653-
return $ EmptyAbs (Nest (b:>ann) Empty)
636+
return $ binderType b : tys
654637

655638
buildUnaryLamExpr
656639
:: (ScopableBuilder r m)
@@ -858,7 +841,7 @@ zeroAt ty = liftEmitBuilder $ go ty where
858841
BaseTy bt -> return $ Con $ Lit $ zeroLit bt
859842
ProdTy tys -> ProdVal <$> mapM go tys
860843
TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i ->
861-
go =<< instantiateTabPiTy (sink tabPi) (Var i)
844+
go =<< instantiate (sink tabPi) [Var i]
862845
_ -> unreachable
863846
zeroLit bt = case bt of
864847
Scalar Float64Type -> Float64Lit 0.0
@@ -1134,8 +1117,7 @@ naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] ->
11341117
naryTopAppInlined f xs = do
11351118
TopFunBinding f' <- lookupEnv f
11361119
case f' of
1137-
DexTopFun _ (TopLam _ _ (LamExpr bs body)) _ ->
1138-
applySubst (bs@@>(SubstVal<$>xs)) body >>= emitBlock
1120+
DexTopFun _ lam _ -> instantiate lam xs >>= emitBlock
11391121
_ -> naryTopApp f xs
11401122
{-# INLINE naryTopAppInlined #-}
11411123

@@ -1194,8 +1176,7 @@ applyIxMethod dict method args = case dict of
11941176
IxDictSpecialized _ d params -> do
11951177
SpecializedDict _ maybeFs <- lookupSpecDict d
11961178
Just fs <- return maybeFs
1197-
TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method
1198-
emitBlock =<< applySubst (bs @@> fmap SubstVal (params ++ args)) body
1179+
instantiate (fs !! fromEnum method) (params ++ args) >>= emitBlock
11991180

12001181
unsafeFromOrdinal :: (SBuilder m, Emits n) => IxType SimpIR n -> Atom SimpIR n -> m n (Atom SimpIR n)
12011182
unsafeFromOrdinal (IxType _ dict) i = applyIxMethod dict UnsafeFromOrdinal [i]
@@ -1262,10 +1243,10 @@ isJustE x = liftEmitBuilder $
12621243
-- Monoid a -> (n=>a) -> a
12631244
reduceE :: (Emits n, Builder r m) => BaseMonoid r n -> Atom r n -> m n (Atom r n)
12641245
reduceE monoid xs = liftEmitBuilder do
1265-
TabTy d (n:>ty) a <- return $ getType xs
1266-
a' <- return $ ignoreHoistFailure $ hoist n a
1267-
getSnd =<< emitRunWriter noHint a' monoid \_ ref ->
1268-
buildFor noHint Fwd (sink $ IxType ty d) \i -> do
1246+
TabPi tabPi <- return $ getType xs
1247+
let a = assumeConst tabPi
1248+
getSnd =<< emitRunWriter noHint a monoid \_ ref ->
1249+
buildFor noHint Fwd (sink $ tabIxType tabPi) \i -> do
12691250
x <- tabApp (sink xs) (Var i)
12701251
emitExpr $ PrimOp $ RefOp (sink $ Var ref) $ MExtend (sink monoid) x
12711252

@@ -1278,11 +1259,10 @@ andMonoid = liftM (BaseMonoid TrueAtom) $ liftBuilder $
12781259
mapE :: (Emits n, ScopableBuilder r m)
12791260
=> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l))
12801261
-> Atom r n -> m n (Atom r n)
1281-
mapE f xs = do
1282-
TabTy d (n:>ty) _ <- return $ getType xs
1283-
buildFor (getNameHint n) Fwd (IxType ty d) \i -> do
1284-
x <- tabApp (sink xs) (Var i)
1285-
f x
1262+
mapE cont xs = do
1263+
TabPi tabPi <- return $ getType xs
1264+
buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> do
1265+
tabApp (sink xs) (Var i) >>= cont
12861266

12871267
-- (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) =
12881268
catMaybesE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n)

src/lib/CheapReduction.hs

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ module CheapReduction
1111
( CheaplyReducibleE (..), cheapReduce, cheapReduceWithDecls, cheapNormalize
1212
, normalizeProj, asNaryProj, normalizeNaryProj
1313
, depPairLeftTy, instantiateTyConDef
14-
, dataDefRep, instantiateDepPairTy, unwrapNewtypeType, repValAtom
14+
, dataDefRep, unwrapNewtypeType, repValAtom
1515
, unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType
1616
, liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..)
1717
, visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2
18-
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy
19-
, bindersToVars, bindersToAtoms)
18+
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated
19+
, bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst)
2020
where
2121

2222
import Control.Applicative
@@ -242,7 +242,7 @@ cheapReduceDictExpr resultTy d = case d of
242242
args' <- mapM cheapReduceE args
243243
InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName
244244
let InstanceBody superclasses _ = body
245-
applySubst (bs@@>(SubstVal <$> args')) (superclasses !! superclassIx)
245+
instantiate (Abs bs (superclasses !! superclassIx)) args'
246246
child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx
247247
InstantiatedGiven f xs ->
248248
reduceApp <|> justSubst
@@ -285,19 +285,16 @@ instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where
285285
cheapReduceE dict >>= \case
286286
DictCon _ (InstanceDict instanceName args) -> dropSubst do
287287
args' <- mapM cheapReduceE args
288-
InstanceDef _ _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName
289-
let method = methods !! i
290-
extendSubst (bs@@>(SubstVal <$> args')) do
291-
method' <- cheapReduceE method
288+
def <- lookupInstanceDef instanceName
289+
withInstantiated def args' \(PairE _ (InstanceBody _ methods)) -> do
290+
method' <- cheapReduceE $ methods !! i
292291
cheapReduceApp method' explicitArgs'
293292
_ -> empty
294293
_ -> empty
295294

296295
cheapReduceApp :: CAtom o -> [CAtom o] -> CheapReducerM CoreIR i o (CAtom o)
297296
cheapReduceApp f xs = case f of
298-
Lam (CoreLamExpr _ (LamExpr bs body)) -> do
299-
let subst = bs @@> fmap SubstVal xs
300-
dropSubst $ extendSubst subst $ cheapReduceE body
297+
Lam lam -> dropSubst $ withInstantiated lam xs \body -> cheapReduceE body
301298
_ -> empty
302299

303300
instance IRRep r => CheaplyReducibleE r (IxType r) (IxType r) where
@@ -450,7 +447,7 @@ projType i ty x = case ty of
450447
DepPairTy t | i == 0 -> return $ depPairLeftTy t
451448
DepPairTy t | i == 1 -> do
452449
xFst <- normalizeProj (ProjectProduct 0) x
453-
instantiateDepPairTy t xFst
450+
instantiate t [xFst]
454451
_ -> error $ "Can't project type: " ++ pprint ty
455452

456453
unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n)
@@ -470,13 +467,39 @@ instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do
470467
applySubst (bs @@> (SubstVal <$> xs)) conDefs
471468
{-# INLINE instantiateTyConDef #-}
472469

473-
instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (EffTy r n)
474-
instantiatePiTy (PiType bs effTy) xs = do
475-
applySubst (bs @@> (SubstVal <$> xs)) effTy
476-
477-
instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n)
478-
instantiateTabPiTy (TabPiType _ b resultTy) x = do
479-
applySubst (b @> SubstVal x) resultTy
470+
assumeConst
471+
:: (IRRep r, HoistableE body, SinkableE body, ToBindersAbs e body r) => e n -> body n
472+
assumeConst e = case toAbs e of Abs bs body -> ignoreHoistFailure $ hoist bs body
473+
474+
instantiate
475+
:: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r)
476+
=> e n -> [Atom r n] -> m n (body n)
477+
instantiate e xs = case toAbs e of
478+
Abs bs body -> applySubst (bs @@> (SubstVal <$> xs)) body
479+
480+
-- "lazy" subst-extending version of `instantiate`
481+
withInstantiated
482+
:: (SubstReader AtomSubstVal m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r)
483+
=> e i -> [Atom r o]
484+
-> (forall i'. body i' -> m i' o a)
485+
-> m i o a
486+
withInstantiated e xs cont = case toAbs e of
487+
Abs bs body -> extendSubst (bs @@> (SubstVal <$> xs)) $ cont body
488+
489+
instantiateNames
490+
:: (EnvReader m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r)
491+
=> e n -> [AtomName r n] -> m n (body n)
492+
instantiateNames e vs = case toAbs e of
493+
Abs bs body -> applyRename (bs @@> vs) body
494+
495+
-- "lazy" subst-extending version of `instantiateNames`
496+
withInstantiatedNames
497+
:: (SubstReader Name m, IRRep r, RenameE body, SinkableE body, ToBindersAbs e body r)
498+
=> e i -> [AtomName r o]
499+
-> (forall i'. body i' -> m i' o a)
500+
-> m i o a
501+
withInstantiatedNames e vs cont = case toAbs e of
502+
Abs bs body -> extendRenamer (bs @@> vs) $ cont body
480503

481504
-- Returns a representation type (type of an TypeCon-typed Newtype payload)
482505
-- given a list of instantiated DataConDefs.
@@ -498,10 +521,6 @@ makeStructRepVal tyConName args = do
498521
_ -> error "wrong number of args"
499522
_ -> return $ ProdVal args
500523

501-
instantiateDepPairTy :: (IRRep r, EnvReader m) => DepPairType r n -> Atom r n -> m n (Type r n)
502-
instantiateDepPairTy (DepPairType _ b rhsTy) x = applyAbs (Abs b rhsTy) (SubstVal x)
503-
{-# INLINE instantiateDepPairTy #-}
504-
505524
-- === traversable terms ===
506525

507526
class Monad m => NonAtomRenamer m i o | m -> i, m -> o where

src/lib/CheckType.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ instance IRRep r => HasType r (Atom r) where
214214
DepPair l r ty -> do
215215
ty' <- checkTypeE TyKind ty
216216
l' <- checkTypeE (depPairLeftTy ty') l
217-
rTy <- instantiateDepPairTy ty' l'
217+
rTy <- instantiate ty' [l']
218218
r |: rTy
219219
return $ DepPairTy ty'
220220
Con con -> typeCheckPrimCon con
@@ -236,7 +236,7 @@ instance IRRep r => HasType r (Atom r) where
236236
DepPairTy t | i == 1 -> do
237237
x' <- renameM x
238238
xFst <- normalizeProj (ProjectProduct 0) x'
239-
instantiateDepPairTy t xFst
239+
instantiate t [xFst]
240240
_ -> throw TypeErr $ "Not a product type:" ++ pprint ty
241241
TypeAsAtom ty -> getTypeE ty
242242

@@ -275,7 +275,7 @@ instance IRRep r => HasType r (Type r) where
275275
DepPairTy t | i == 1 -> do
276276
x' <- renameM x
277277
xFst <- normalizeProj (ProjectProduct 0) x'
278-
instantiateDepPairTy t xFst
278+
instantiate t [xFst]
279279
_ -> throw TypeErr $ "Not a product type:" ++ pprint ty
280280

281281
instance HasType CoreIR SimpInCore where

src/lib/Imp.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ toImpRefOp refDest' m = do
373373
idx <- unsafeFromOrdinalImp (sink ixTy) i
374374
xElt <- liftBuilderImp $ tabApp (sink x) (sink idx)
375375
yElt <- liftBuilderImp $ tabApp (sink y) (sink idx)
376-
eltTy <- instantiateTabPiTy (sink t) idx
376+
eltTy <- instantiate (sink t) [idx]
377377
ithDest <- indexDest (sink accDest) idx
378378
liftMonoidCombine ithDest eltTy (sink bc) xElt yElt
379379
_ -> error $ "Base monoid type mismatch: can't lift " ++
@@ -584,7 +584,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do
584584
emitLoop noHint Fwd n \i -> do
585585
idx <- unsafeFromOrdinalImp (sink ixTy) i
586586
x' <- sinkM x
587-
eltTy <- instantiateTabPiTy (sink t) idx
587+
eltTy <- instantiate (sink t) [idx]
588588
ithDest <- indexDest (sink accDest) idx
589589
liftMonoidEmpty ithDest eltTy x'
590590
_ -> error $ "Base monoid type mismatch: can't lift " ++
@@ -1003,7 +1003,7 @@ buildGarbageVal ty =
10031003

10041004
indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n)
10051005
indexDest (Dest (TabPi tabTy) tree) i = do
1006-
eltTy <- instantiateTabPiTy tabTy i
1006+
eltTy <- instantiate tabTy [i]
10071007
ord <- ordinalImp (tabIxType tabTy) i
10081008
leafTys <- typeToTree $ TabPi tabTy
10091009
Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do
@@ -1027,7 +1027,7 @@ indexRepValParam :: Emits n
10271027
-> (IExpr n -> SubstImpM i n (IExpr n))
10281028
-> SubstImpM i n (SRepVal n)
10291029
indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do
1030-
eltTy <- instantiateTabPiTy tabTy i
1030+
eltTy <- instantiate tabTy [i]
10311031
ord <- ordinalImp (tabIxType tabTy) i
10321032
leafTys <- typeToTree (TabPi tabTy)
10331033
vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do

src/lib/Inference.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of
887887
-- TODO: check for the case that we're given some of the implicit dependent pair args explicitly
888888
lhsVal <- Var <$> freshInferenceName MiscInfVar lhsTy
889889
-- TODO: make an InfVarDesc case for dep pair instantiation
890-
rhsTy <- instantiateDepPairTy depPairTy lhsVal
890+
rhsTy <- instantiate depPairTy [lhsVal]
891891
rhsVal <- checkSigma noHint expr rhsTy
892892
return $ DepPair lhsVal rhsVal depPairTy
893893
_ -> fallback
@@ -996,7 +996,7 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do
996996
case reqTy of
997997
Check (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do
998998
lhs' <- checkSigmaDependent noHint lhs lhsTy
999-
rhsTy <- instantiateDepPairTy ty lhs'
999+
rhsTy <- instantiate ty [lhs']
10001000
rhs' <- checkSigma noHint rhs rhsTy
10011001
return $ DepPair lhs' rhs' ty
10021002
_ -> throw TypeErr $ "Can't infer the type of a dependent pair; please annotate it"

src/lib/Lower.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftE
6767
True -> do
6868
refreshAbs (Abs bs body) \bs' body' -> do
6969
xs <- bindersToAtoms bs'
70-
EffTy _ resultTy <- instantiatePiTy (sink piTy) xs
70+
EffTy _ resultTy <- instantiate (sink piTy) xs
7171
Abs b body'' <- lowerFullySequentialBlock resultTy body'
7272
return $ LamExpr (bs' >>> UnaryNest b) body''
7373
False -> do

0 commit comments

Comments
 (0)