11module QueryType (
22 getType , getTypeSubst , HasType ,
3- depPairLeftTy ,
3+ getEffects , getEffectsSubst ,
4+ computeAbsEffects , declNestEffects ,
5+ depPairLeftTy , extendEffect ,
46 getAppType , getTabAppType , getMethodType , getBaseMonoidType , getReferentTy ,
57 getMethodIndex ,
68 instantiateDataDef , instantiateDepPairTy , instantiatePi , instantiateTabPi ,
79 litType , lamExprTy ,
810 numNaryPiArgs , naryLamExprType ,
9- projectLength , sourceNameType ,
11+ oneEffect , projectLength , sourceNameType ,
1012 ) where
1113
1214import Control.Monad
1315import Data.Foldable (toList )
1416import Data.List (elemIndex )
1517import qualified Data.List.NonEmpty as NE
1618import qualified Data.Map.Strict as M
19+ import qualified Data.Set as S
1720
1821import CheapReduction (cheapNormalize )
1922import Err
@@ -37,12 +40,28 @@ getType :: (EnvReader m, HasType e) => e n -> m n (Type n)
3740getType e = liftTypeQueryM idSubst $ getTypeE e
3841{-# INLINE getType #-}
3942
40- -- === Exposed helpers for querying types ===
43+ -- === Querying effects ===
44+
45+ getEffects :: (EnvReader m , HasEffectsE e ) => e n -> m n (EffectRow n )
46+ getEffects e = liftTypeQueryM idSubst $ getEffectsImpl e
47+ {-# INLINE getEffects #-}
48+
49+ getEffectsSubst :: (EnvReader2 m , SubstReader AtomSubstVal m , HasEffectsE e )
50+ => e i -> m i o (EffectRow o )
51+ getEffectsSubst e = do
52+ subst <- getSubst
53+ liftTypeQueryM subst $ getEffectsImpl e
54+ {-# INLINE getEffectsSubst #-}
55+
56+ -- === Exposed helpers for querying types and effects ===
4157
4258depPairLeftTy :: DepPairType n -> Type n
4359depPairLeftTy (DepPairType (_:> ty) _) = ty
4460{-# INLINE depPairLeftTy #-}
4561
62+ extendEffect :: Effect n -> EffectRow n -> EffectRow n
63+ extendEffect eff (EffectRow effs t) = EffectRow (S. insert eff effs) t
64+
4665getAppType :: EnvReader m => Type n -> [Atom n ] -> m n (Type n )
4766getAppType f xs = liftTypeQueryM idSubst $ typeApp f xs
4867{-# INLINE getAppType #-}
@@ -162,6 +181,33 @@ sourceNameType v = do
162181 UClassVar v' -> lookupEnv v' >>= \ case ClassBinding def -> return $ getClassTy def
163182 UMethodVar v' -> lookupEnv v' >>= \ case MethodBinding _ _ e -> getType e
164183
184+ oneEffect :: Effect n -> EffectRow n
185+ oneEffect eff = EffectRow (S. singleton eff) Nothing
186+
187+ -- === computing effects ===
188+
189+ computeAbsEffects :: (EnvExtender m , SubstE Name e )
190+ => Abs (Nest Decl ) e n -> m n (Abs (Nest Decl ) (EffectRow `PairE ` e ) n )
191+ computeAbsEffects it = refreshAbs it \ decls result -> do
192+ effs <- declNestEffects decls
193+ return $ Abs decls (effs `PairE ` result)
194+ {-# INLINE computeAbsEffects #-}
195+
196+ declNestEffects :: (EnvReader m ) => Nest Decl n l -> m l (EffectRow l )
197+ declNestEffects decls = liftEnvReaderM $ declNestEffectsRec decls mempty
198+ {-# INLINE declNestEffects #-}
199+
200+ declNestEffectsRec :: Nest Decl n l -> EffectRow l -> EnvReaderM l (EffectRow l )
201+ declNestEffectsRec Empty ! acc = return acc
202+ declNestEffectsRec n@ (Nest decl rest) ! acc = withExtEvidence n do
203+ expr <- sinkM $ declExpr decl
204+ deff <- getEffects expr
205+ acc' <- sinkM $ acc <> deff
206+ declNestEffectsRec rest acc'
207+ where
208+ declExpr :: Decl n l -> Expr n
209+ declExpr (Let _ (DeclBinding _ _ expr)) = expr
210+
165211-- === implementation of querying types ===
166212
167213newtype TypeQueryM (i :: S ) (o :: S ) (a :: * ) = TypeQueryM {
@@ -642,3 +688,89 @@ getClassTy (ClassDef _ _ bs _ _) = go bs
642688 go Empty = TyKind
643689 go (Nest (b:> ty) rest) = Pi $ PiType (PiBinder b ty PlainArrow ) Pure $ go rest
644690
691+ -- === querying effects implementation ===
692+
693+ class HasEffectsE (e :: E ) where
694+ getEffectsImpl :: e i -> TypeQueryM i o (EffectRow o )
695+
696+ instance HasEffectsE Expr where
697+ getEffectsImpl = exprEffects
698+ {-# INLINE getEffectsImpl #-}
699+
700+ exprEffects :: Expr i -> TypeQueryM i o (EffectRow o )
701+ exprEffects expr = case expr of
702+ Atom _ -> return Pure
703+ App f xs -> do
704+ fTy <- getTypeSubst f
705+ case fromNaryPiType (length xs) fTy of
706+ Just (NaryPiType bs effs _) -> do
707+ xs' <- mapM substM xs
708+ let subst = bs @@> fmap SubstVal xs'
709+ applySubst subst effs
710+ Nothing -> error $
711+ " Not a " ++ show (length xs + 1 ) ++ " -argument pi type: " ++ pprint expr
712+ TabApp _ _ -> return Pure
713+ Op op -> case op of
714+ PrimEffect ref m -> do
715+ getTypeSubst ref >>= \ case
716+ RefTy (Var h) _ ->
717+ return $ case m of
718+ MGet -> oneEffect (RWSEffect State $ Just h)
719+ MPut _ -> oneEffect (RWSEffect State $ Just h)
720+ MAsk -> oneEffect (RWSEffect Reader $ Just h)
721+ -- XXX: We don't verify the base monoid. See note about RunWriter.
722+ MExtend _ _ -> oneEffect (RWSEffect Writer $ Just h)
723+ _ -> error " References must have reference type"
724+ ThrowException _ -> return $ oneEffect ExceptionEffect
725+ IOAlloc _ _ -> return $ oneEffect IOEffect
726+ IOFree _ -> return $ oneEffect IOEffect
727+ PtrLoad _ -> return $ oneEffect IOEffect
728+ PtrStore _ _ -> return $ oneEffect IOEffect
729+ _ -> return Pure
730+ Hof hof -> case hof of
731+ For _ f -> functionEffs f
732+ -- The tiled and scalar bodies should have the same effects, but
733+ -- that's checked elsewhere. If they are the same, merging them
734+ -- with <> is a noop.
735+ Tile _ tiled scalar -> liftM2 (<>) (functionEffs tiled) $ functionEffs scalar
736+ While body -> functionEffs body
737+ Linearize _ -> return Pure -- Body has to be a pure function
738+ Transpose _ -> return Pure -- Body has to be a pure function
739+ RunWriter _ f -> rwsFunEffects Writer f
740+ RunReader _ f -> rwsFunEffects Reader f
741+ RunState _ f -> rwsFunEffects State f
742+ PTileReduce _ _ _ -> return mempty
743+ RunIO f -> do
744+ effs <- functionEffs f
745+ return $ deleteEff IOEffect effs
746+ CatchException f -> do
747+ effs <- functionEffs f
748+ return $ deleteEff ExceptionEffect effs
749+ Case _ _ _ effs -> substM effs
750+
751+ instance HasEffectsE Block where
752+ getEffectsImpl (Block (BlockAnn _ effs) _ _) = substM effs
753+ getEffectsImpl (Block NoBlockAnn _ _) = return Pure
754+ {-# INLINE getEffectsImpl #-}
755+
756+ instance HasEffectsE Alt where
757+ getEffectsImpl (Abs bs body) =
758+ substBinders bs \ bs' ->
759+ ignoreHoistFailure . hoist bs' <$> getEffectsImpl body
760+ {-# INLINE getEffectsImpl #-}
761+
762+ functionEffs :: Atom i -> TypeQueryM i o (EffectRow o )
763+ functionEffs f = getTypeSubst f >>= \ case
764+ Pi (PiType b effs _) -> return $ ignoreHoistFailure $ hoist b effs
765+ _ -> error " Expected a function type"
766+
767+ rwsFunEffects :: RWS -> Atom i -> TypeQueryM i o (EffectRow o )
768+ rwsFunEffects rws f = getTypeSubst f >>= \ case
769+ BinaryFunTy h ref effs _ -> do
770+ let effs' = ignoreHoistFailure $ hoist ref effs
771+ let effs'' = deleteEff (RWSEffect rws (Just (binderName h))) effs'
772+ return $ ignoreHoistFailure $ hoist h effs''
773+ _ -> error " Expected a binary function type"
774+
775+ deleteEff :: Effect n -> EffectRow n -> EffectRow n
776+ deleteEff eff (EffectRow effs t) = EffectRow (S. delete eff effs) t
0 commit comments