Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,14 @@ typeCheckPrimOp op = case op of

typeCheckPrimHof :: Typer m => PrimHof (Atom i) -> m i o (Type o)
typeCheckPrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of
Map f -> getTypeE f
Map fun array -> do
Pi (PiType (PiBinder b argTy PlainArrow) Pure resEltTy) <- getTypeE fun
let resEltTy' = ignoreHoistFailure $ hoist b resEltTy
TabPi (TabPiType binder argEltTy) <- getTypeE array
let argEltTy' = ignoreHoistFailure $ hoist binder argEltTy
checkAlphaEq argTy argEltTy'
refreshAbs (Abs binder UnitE) \binder' _ ->
return $ TabPi $ TabPiType binder' (sink resEltTy')
For _ ixDict f -> do
ixTy <- ixTyFromDict =<< substM ixDict
Pi (PiType (PiBinder b argTy PlainArrow) eff eltTy) <- getTypeE f
Expand Down
16 changes: 8 additions & 8 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -560,17 +560,17 @@ toImpHof :: Emits o => Maybe (Dest o) -> PrimHof (Atom i) -> SubstImpM i o (Atom
toImpHof maybeDest hof = do
resultTy <- getTypeSubst (Hof hof)
case hof of
Map (TabLam (TabLamExpr (b:>ixTy) body)) -> do
-- TODO: The following code block is identical to the `For` case below.
-- Reuse the code for the `For` case by generating `For (Lam ...)`, with
-- suitable `Lam ...`, when currently a `Map (TabLam ...)` is generated.
Map (Lam (LamExpr b body)) array -> do
rDest <- allocDest maybeDest resultTy
ixTy' <- substM ixTy
n <- indexSetSizeImp ixTy'
TabPi (TabPiType (_:>ixTy) _) <- getTypeSubst array
array' <- substM array
n <- indexSetSizeImp ixTy
emitLoop noHint Fwd n \i -> do
idx <- unsafeFromOrdinalImp (sink ixTy') i
idx <- unsafeFromOrdinalImp (sink ixTy) i
ithArg <- dropSubst $ translateExpr Nothing $
TabApp (sink array') $ idx :| []
ithDest <- destGet (sink rDest) idx
void $ extendSubst (b @> SubstVal idx) $
void $ extendSubst (b @> SubstVal ithArg) $
translateBlock (Just ithDest) body
destToAtom rDest
For d ixDict (Lam (LamExpr b body)) -> do
Expand Down
92 changes: 30 additions & 62 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,20 @@ getImplicitArg (PiBinder _ argTy arr) = case arr of
return $ Just $ Con $ DictHole (AlwaysEqual ctx) argTy
_ -> return Nothing

etaExpand :: EmitsInf n => Atom n -> InfererM i n (Atom n)
etaExpand fun = do
ty <- getType fun
case ty of
Pi (PiType (PiBinder b argTy arr) eff _) -> do
case fun of
Lam _ -> pure fun
_ -> buildLamInf noHint arr argTy
(\b' -> applySubst (b @> b') eff)
(\x -> do
Distinct <- getDistinct
app (sink fun) (Var x))
_ -> error "atom must have pi type"

checkOrInferRho :: forall i o. EmitsBoth o
=> UExpr i -> RequiredTy RhoType o -> InfererM i o (Atom o)
checkOrInferRho (WithSrcE pos expr) reqTy = do
Expand All @@ -944,68 +958,22 @@ checkOrInferRho (WithSrcE pos expr) reqTy = do
ixTy <- asIxType ty'
matchRequirement $ TabLam $ TabLamExpr (b':>ixTy) body'
UMap fun array -> do
argElemVar <- liftM Var $ freshInferenceName (TC TypeKind)
resElemVar <- liftM Var $ freshInferenceName (TC TypeKind)
funTy <- naryNonDepPiType PlainArrow Pure [argElemVar] resElemVar
fun' <- checkOrInferRho fun (Check funTy)

arrayReqTy <- case reqTy of
Check (TabPi (TabPiType (b:>ixTy) resElemTy)) -> do
-- TODO: Throw a graceful error if `resElemTy` depends on `b`.
let resElemTy' = ignoreHoistFailure $ hoist b resElemTy
constrainEq resElemVar resElemTy'
liftM (Check . TabPi) $ nonDepTabPiType ixTy argElemVar
Check _ -> return Infer
Infer -> return Infer
array' <- checkOrInferRho array arrayReqTy

-- Construct the `TabLam` for `Map`. This should probably not be done here
-- alongside the inference code; perhaps `AbstractSyntax.hs` would be a
-- better place for this. However, if we replaced replaced the `fun` and
-- `array` arguments to `UMap` with a `UTabLam`, this would make typing
-- failures less informative at the source code level (since the source code
-- contains concrete syntax for `fun` and `array`, but not for the
-- constructed `UTabLam`.) The cleanest alternative, however, would probably
-- be to place the construction of the `TabLam` (or of an equivalent block
-- paired with binders for its free variables) in `Lower.hs` or `Imp.hs`.
-- However, at that point we would require the expressions in the block
-- (including the expressions in the decls inside the block) to be
-- simplified; and the `TabApp` and `App` expressions below are apparently
-- not simplified.
TabPi (TabPiType (bA:>ixTy) argElemTy) <- getType array'
-- TODO: Throw a graceful error if `argElemTy` depends on `bA`.
let argElemTy' = ignoreHoistFailure $ hoist bA argElemTy
-- NOTE: In the definition of `arrayReqTy` we have already introduced a
-- constaint for `resElemVar`; but we still need to constrain `argElemVar`.
-- (Additionally `resElemVar` should also be constrained by the
-- `matchRequirement` below, but this is not the case for `argElemVar`.)
constrainEq argElemVar argElemTy'
Pi (PiType (PiBinder bF _ _) _ resElemTy) <- getType fun'
-- TODO: Throw a graceful error if `resElemTy` depends on `bF`.
let resElemTy' = ignoreHoistFailure $ hoist bF resElemTy

f <- withFreshBinder noHint ixTy \b0 -> do
let binder = b0:>ixTy
-- I am having trouble getting the following to work when trying to use a
-- single call to `withFreshBinders` (plural!) only. The problem appears
-- to be that `withFreshBinders` does not make available evidence of
-- `Distinct` for the intermediate scope, i.e. the scope that has `b1` but
-- not `b2`. Without this evidence, `sink fun'` in the let-block below is
-- not valid.
body <- withFreshBinder noHint (sink argElemTy') \b1 ->
withFreshBinder noHint (sink resElemTy') \b2 ->
let indexName = binderName b0
argElem = TabApp (sink array') $ (Var indexName) :| []
declArgElem = Let b1 (DeclBinding PlainLet (sink argElemTy') argElem)
funApp = App (sink fun') $ (Var $ binderName b1) :| []
declResElem = Let b2 (DeclBinding PlainLet (sink resElemTy') funApp)
ann = BlockAnn (sink resElemTy') Pure
block = Block ann (Nest declArgElem (Nest declResElem Empty)) (Var $ binderName b2)
in return block
return $ TabLam $ TabLamExpr binder body

result <- liftM Var $ emit $ Hof $ Map f
matchRequirement result
array' <- inferRho array
arrayTy <- getType array'
case arrayTy of
TabPi (TabPiType (b:>_) argElemTy) -> do
argElemTy' <- case hoist b argElemTy of
HoistSuccess ty -> return ty
HoistFailure _ -> throw TypeErr "expected non-dependent array type"
resElemVar <- liftM Var $ freshInferenceName (TC TypeKind)
funTy <- naryNonDepPiType PlainArrow Pure [argElemTy'] resElemVar
fun' <- checkOrInferRho fun (Check funTy)
-- Eta-expand `fun'` into a `Lam`. Later on we make use of the invariant
-- that the first argument of `Map` is a `Lam`.
fun'' <- etaExpand fun'
result <- liftM Var $ emit $ Hof $ Map fun'' array'
matchRequirement result
_ -> throw TypeErr "expected array type"
UFor dir (UForExpr b body) -> do
allowedEff <- getAllowedEffects
let uLamExpr = ULamExpr PlainArrow b body
Expand Down
9 changes: 7 additions & 2 deletions src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,12 @@ getTypePrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of
Pi (PiType (PiBinder b _ _) _ eltTy) <- getTypeE f
ixTy <- ixTyFromDict =<< substM dict
return $ TabTy (b:>ixTy) eltTy
Map f -> getTypeE f
Map fun array -> do
Pi (PiType (PiBinder b _ _) _ resEltTy) <- getTypeE fun
let resEltTy' = ignoreHoistFailure $ hoist b resEltTy
TabPi (TabPiType binder _) <- getTypeE array
refreshAbs (Abs binder UnitE) \binder' _ ->
return $ TabPi $ TabPiType binder' (sink resEltTy')
While _ -> return UnitTy
Linearize f -> do
Pi (PiType (PiBinder binder a PlainArrow) Pure b) <- getTypeE f
Expand Down Expand Up @@ -798,7 +803,7 @@ exprEffects expr = case expr of
_ -> return Pure
Hof hof -> case hof of
For _ _ f -> functionEffs f
Map _ -> return Pure
Map _ _ -> return Pure
While body -> functionEffs body
Linearize _ -> return Pure -- Body has to be a pure function
Transpose _ -> return Pure -- Body has to be a pure function
Expand Down
16 changes: 13 additions & 3 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -835,9 +835,19 @@ projectDictMethod d i = do

simplifyHof :: Emits o => Hof i -> SimplifyM i o (Atom o)
simplifyHof hof = case hof of
Map f -> do
f' <- simplifyAtom f
liftM Var $ emit $ Hof $ Map f'
Map fun array -> do
(fun', Abs b recon) <- simplifyLam fun
array' <- simplifyAtom array
ans <- liftM Var $ emit $ Hof $ Map fun' array'
case recon of
IdentityRecon -> return ans
LamRecon reconAbs -> do
TabPi (TabPiType (_:>ixTy) _) <- getType array'
buildTabLam noHint ixTy \i -> do
locals <- tabApp (sink ans) $ Var i
ithArg <- emitAtomToName =<< (tabApp (sink array') $ Var i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of emitting to name, you could b @> SubstVal ithArg below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

reconAbs' <- applySubst (b @> ithArg) reconAbs
applyReconAbs reconAbs' locals
For d ixDict lam -> do
ixTy@(IxType _ ixDict') <- ixTyFromDict =<< substM ixDict
(lam', Abs b recon) <- simplifyLam lam
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Types/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ traversePrimOp = inline traverse

data PrimHof e =
For ForAnn e e -- ix dict, body lambda
| Map e -- body tab-lambda
| Map e e -- lambda, array
| While e
| RunReader e e
| RunWriter (Maybe e) (BaseMonoidP e) e
Expand Down