Skip to content

Commit c98dc0c

Browse files
committed
Move effects code into QueryType and generalize to accept a substituion.
Also rename effectsE to getEffects, for consistency with getType.
1 parent 5ac168f commit c98dc0c

File tree

7 files changed

+145
-135
lines changed

7 files changed

+145
-135
lines changed

src/lib/Builder.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ buildCase scrut resultTy indexedAltBody = do
716716
blk <- buildBlock do
717717
ListE xs' <- sinkM $ ListE xs
718718
indexedAltBody i xs'
719-
eff <- effectsE blk
719+
eff <- getEffects blk
720720
return $ blk `PairE` eff
721721
return (Abs bs' body, ignoreHoistFailure $ hoist bs' eff')
722722
liftM Var $ emit $ Case scrut alts resultTy $ mconcat effs
@@ -751,7 +751,7 @@ buildEffLam rws hint ty body = do
751751
-- Contract the type of the produced function to only mention
752752
-- the effects actually demanded by the body. This is safe because
753753
-- it's immediately consumed by an effect discharge primitive.
754-
effs <- effectsE body'
754+
effs <- getEffects body'
755755
return $ Lam $ LamExpr (LamBinder b ty' PlainArrow effs) body'
756756

757757
buildForAnn
@@ -763,7 +763,7 @@ buildForAnn hint ann ty body = do
763763
lam <- withFreshBinder hint (LamBinding PlainArrow ty) \b -> do
764764
let v = binderName b
765765
body' <- buildBlock $ body $ sink v
766-
effs <- effectsE body'
766+
effs <- getEffects body'
767767
return $ Lam $ LamExpr (LamBinder b ty PlainArrow effs) body'
768768
liftM Var $ emit $ Hof $ For ann lam
769769

src/lib/Inference.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ inferNaryApp fCtx f args = addSrcContext fCtx do
10901090
Just naryPi <- asNaryPiType <$> Pi <$> fromPiType True PlainArrow fTy
10911091
(inferredArgs, remaining) <- inferNaryAppArgs naryPi args
10921092
let appExpr = App f inferredArgs
1093-
addEffects =<< exprEffects appExpr
1093+
addEffects =<< getEffects appExpr
10941094
partiallyApplied <- Var <$> emit appExpr
10951095
case nonEmpty remaining of
10961096
Nothing ->

src/lib/Linearize.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ isTrivialForAD expr = do
147147
trivialTy <- (maybeTangentType <$> getType expr) >>= \case
148148
Nothing -> return False
149149
Just tTy -> isSingletonType tTy
150-
hasActiveEffs <- exprEffects expr >>= \case
150+
hasActiveEffs <- getEffects expr >>= \case
151151
Pure -> return False
152152
-- TODO: Be more precise here, such as checking
153153
-- whether the effects are themselves active.

src/lib/QueryType.hs

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
module QueryType (
22
getType, getTypeSubst, HasType,
3-
depPairLeftTy,
3+
getEffects, getEffectsSubst,
4+
computeAbsEffects, declNestEffects,
5+
depPairLeftTy, extendEffect,
46
getAppType, getTabAppType, getMethodType, getBaseMonoidType, getReferentTy,
57
getMethodIndex,
68
instantiateDataDef, instantiateDepPairTy, instantiatePi, instantiateTabPi,
79
litType, lamExprTy,
810
numNaryPiArgs, naryLamExprType,
9-
projectLength, sourceNameType,
11+
oneEffect, projectLength, sourceNameType,
1012
) where
1113

1214
import Control.Monad
1315
import Data.Foldable (toList)
1416
import Data.List (elemIndex)
1517
import qualified Data.List.NonEmpty as NE
1618
import qualified Data.Map.Strict as M
19+
import qualified Data.Set as S
1720

1821
import CheapReduction (cheapNormalize)
1922
import Err
@@ -37,12 +40,28 @@ getType :: (EnvReader m, HasType e) => e n -> m n (Type n)
3740
getType e = liftTypeQueryM idSubst $ getTypeE e
3841
{-# INLINE getType #-}
3942

40-
-- === Exposed helpers for querying types ===
43+
-- === Querying effects ===
44+
45+
getEffects :: (EnvReader m, HasEffectsE e) => e n -> m n (EffectRow n)
46+
getEffects e = liftTypeQueryM idSubst $ getEffectsImpl e
47+
{-# INLINE getEffects #-}
48+
49+
getEffectsSubst :: (EnvReader2 m, SubstReader AtomSubstVal m, HasEffectsE e)
50+
=> e i -> m i o (EffectRow o)
51+
getEffectsSubst e = do
52+
subst <- getSubst
53+
liftTypeQueryM subst $ getEffectsImpl e
54+
{-# INLINE getEffectsSubst #-}
55+
56+
-- === Exposed helpers for querying types and effects ===
4157

4258
depPairLeftTy :: DepPairType n -> Type n
4359
depPairLeftTy (DepPairType (_:>ty) _) = ty
4460
{-# INLINE depPairLeftTy #-}
4561

62+
extendEffect :: Effect n -> EffectRow n -> EffectRow n
63+
extendEffect eff (EffectRow effs t) = EffectRow (S.insert eff effs) t
64+
4665
getAppType :: EnvReader m => Type n -> [Atom n] -> m n (Type n)
4766
getAppType f xs = liftTypeQueryM idSubst $ typeApp f xs
4867
{-# INLINE getAppType #-}
@@ -162,6 +181,33 @@ sourceNameType v = do
162181
UClassVar v' -> lookupEnv v' >>= \case ClassBinding def -> return $ getClassTy def
163182
UMethodVar v' -> lookupEnv v' >>= \case MethodBinding _ _ e -> getType e
164183

184+
oneEffect :: Effect n -> EffectRow n
185+
oneEffect eff = EffectRow (S.singleton eff) Nothing
186+
187+
-- === computing effects ===
188+
189+
computeAbsEffects :: (EnvExtender m, SubstE Name e)
190+
=> Abs (Nest Decl) e n -> m n (Abs (Nest Decl) (EffectRow `PairE` e) n)
191+
computeAbsEffects it = refreshAbs it \decls result -> do
192+
effs <- declNestEffects decls
193+
return $ Abs decls (effs `PairE` result)
194+
{-# INLINE computeAbsEffects #-}
195+
196+
declNestEffects :: (EnvReader m) => Nest Decl n l -> m l (EffectRow l)
197+
declNestEffects decls = liftEnvReaderM $ declNestEffectsRec decls mempty
198+
{-# INLINE declNestEffects #-}
199+
200+
declNestEffectsRec :: Nest Decl n l -> EffectRow l -> EnvReaderM l (EffectRow l)
201+
declNestEffectsRec Empty !acc = return acc
202+
declNestEffectsRec n@(Nest decl rest) !acc = withExtEvidence n do
203+
expr <- sinkM $ declExpr decl
204+
deff <- getEffects expr
205+
acc' <- sinkM $ acc <> deff
206+
declNestEffectsRec rest acc'
207+
where
208+
declExpr :: Decl n l -> Expr n
209+
declExpr (Let _ (DeclBinding _ _ expr)) = expr
210+
165211
-- === implementation of querying types ===
166212

167213
newtype TypeQueryM (i::S) (o::S) (a :: *) = TypeQueryM {
@@ -642,3 +688,89 @@ getClassTy (ClassDef _ _ bs _ _) = go bs
642688
go Empty = TyKind
643689
go (Nest (b:>ty) rest) = Pi $ PiType (PiBinder b ty PlainArrow) Pure $ go rest
644690

691+
-- === querying effects implementation ===
692+
693+
class HasEffectsE (e::E) where
694+
getEffectsImpl :: e i -> TypeQueryM i o (EffectRow o)
695+
696+
instance HasEffectsE Expr where
697+
getEffectsImpl = exprEffects
698+
{-# INLINE getEffectsImpl #-}
699+
700+
exprEffects :: Expr i -> TypeQueryM i o (EffectRow o)
701+
exprEffects expr = case expr of
702+
Atom _ -> return Pure
703+
App f xs -> do
704+
fTy <- getTypeSubst f
705+
case fromNaryPiType (length xs) fTy of
706+
Just (NaryPiType bs effs _) -> do
707+
xs' <- mapM substM xs
708+
let subst = bs @@> fmap SubstVal xs'
709+
applySubst subst effs
710+
Nothing -> error $
711+
"Not a " ++ show (length xs + 1) ++ "-argument pi type: " ++ pprint expr
712+
TabApp _ _ -> return Pure
713+
Op op -> case op of
714+
PrimEffect ref m -> do
715+
getTypeSubst ref >>= \case
716+
RefTy (Var h) _ ->
717+
return $ case m of
718+
MGet -> oneEffect (RWSEffect State $ Just h)
719+
MPut _ -> oneEffect (RWSEffect State $ Just h)
720+
MAsk -> oneEffect (RWSEffect Reader $ Just h)
721+
-- XXX: We don't verify the base monoid. See note about RunWriter.
722+
MExtend _ _ -> oneEffect (RWSEffect Writer $ Just h)
723+
_ -> error "References must have reference type"
724+
ThrowException _ -> return $ oneEffect ExceptionEffect
725+
IOAlloc _ _ -> return $ oneEffect IOEffect
726+
IOFree _ -> return $ oneEffect IOEffect
727+
PtrLoad _ -> return $ oneEffect IOEffect
728+
PtrStore _ _ -> return $ oneEffect IOEffect
729+
_ -> return Pure
730+
Hof hof -> case hof of
731+
For _ f -> functionEffs f
732+
-- The tiled and scalar bodies should have the same effects, but
733+
-- that's checked elsewhere. If they are the same, merging them
734+
-- with <> is a noop.
735+
Tile _ tiled scalar -> liftM2 (<>) (functionEffs tiled) $ functionEffs scalar
736+
While body -> functionEffs body
737+
Linearize _ -> return Pure -- Body has to be a pure function
738+
Transpose _ -> return Pure -- Body has to be a pure function
739+
RunWriter _ f -> rwsFunEffects Writer f
740+
RunReader _ f -> rwsFunEffects Reader f
741+
RunState _ f -> rwsFunEffects State f
742+
PTileReduce _ _ _ -> return mempty
743+
RunIO f -> do
744+
effs <- functionEffs f
745+
return $ deleteEff IOEffect effs
746+
CatchException f -> do
747+
effs <- functionEffs f
748+
return $ deleteEff ExceptionEffect effs
749+
Case _ _ _ effs -> substM effs
750+
751+
instance HasEffectsE Block where
752+
getEffectsImpl (Block (BlockAnn _ effs) _ _) = substM effs
753+
getEffectsImpl (Block NoBlockAnn _ _) = return Pure
754+
{-# INLINE getEffectsImpl #-}
755+
756+
instance HasEffectsE Alt where
757+
getEffectsImpl (Abs bs body) =
758+
substBinders bs \bs' ->
759+
ignoreHoistFailure . hoist bs' <$> getEffectsImpl body
760+
{-# INLINE getEffectsImpl #-}
761+
762+
functionEffs :: Atom i -> TypeQueryM i o (EffectRow o)
763+
functionEffs f = getTypeSubst f >>= \case
764+
Pi (PiType b effs _) -> return $ ignoreHoistFailure $ hoist b effs
765+
_ -> error "Expected a function type"
766+
767+
rwsFunEffects :: RWS -> Atom i -> TypeQueryM i o (EffectRow o)
768+
rwsFunEffects rws f = getTypeSubst f >>= \case
769+
BinaryFunTy h ref effs _ -> do
770+
let effs' = ignoreHoistFailure $ hoist ref effs
771+
let effs'' = deleteEff (RWSEffect rws (Just (binderName h))) effs'
772+
return $ ignoreHoistFailure $ hoist h effs''
773+
_ -> error "Expected a binary function type"
774+
775+
deleteEff :: Effect n -> EffectRow n -> EffectRow n
776+
deleteEff eff (EffectRow effs t) = EffectRow (S.delete eff effs) t

src/lib/Simplify.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ caseComputingEffs
154154
:: forall m n. (MonadFail1 m, EnvReader m)
155155
=> Atom n -> [Alt n] -> Type n -> m n (Expr n)
156156
caseComputingEffs scrut alts resultTy = do
157-
Case scrut alts resultTy <$> foldMapM effectsE alts
157+
Case scrut alts resultTy <$> foldMapM getEffects alts
158158
{-# INLINE caseComputingEffs #-}
159159

160160
defuncCase :: Emits o => Atom o -> [Alt i] -> Type o -> SimplifyM i o (Atom o)
@@ -702,7 +702,7 @@ exceptToMaybeExpr expr = case expr of
702702

703703
hasExceptions :: (EnvReader m, MonadFail1 m) => Expr n -> m n Bool
704704
hasExceptions expr = do
705-
(EffectRow effs t) <- exprEffects expr
705+
(EffectRow effs t) <- getEffects expr
706706
case t of
707707
Nothing -> return $ ExceptionEffect `S.member` effs
708708
Just _ -> error "Shouldn't have tail left"

src/lib/Transpose.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ substExprIfNonlin expr =
124124
True -> return Nothing
125125
False -> do
126126
expr' <- substNonlin expr
127-
exprEffects expr' >>= isLinEff >>= \case
127+
getEffects expr' >>= isLinEff >>= \case
128128
True -> return Nothing
129129
False -> return $ Just expr'
130130

0 commit comments

Comments
 (0)