Skip to content

Commit e05ed56

Browse files
committed
Extend affine type-checking for cover case expressions.
1 parent 19b9c92 commit e05ed56

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

src/lib/CheckType.hs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ import Types.Core
3434
import Types.Imp
3535
import Types.Primitives
3636
import Types.Source
37-
import Util (forMZipped_)
3837

3938
-- === top-level API ===
4039

@@ -62,6 +61,7 @@ class ( Monad2 m, Fallible2 m, SubstReader Name m
6261
, EnvReader2 m, EnvExtender2 m)
6362
=> Typer (m::MonadKind2) (r::IR) | m -> r where
6463
affineUsed :: AtomName r o -> m i o ()
64+
parallelAffines_ :: [m i o ()] -> m i o ()
6565

6666
newtype TyperT (m::MonadKind) (r::IR) (i::S) (o::S) (a :: *) =
6767
TyperT { runTyperT' :: SubstReaderT Name (StateT1 (NameMap (AtomNameC r) Int) (EnvReaderT m)) i o a }
@@ -81,18 +81,7 @@ liftTyperT cont =
8181
{-# INLINE liftTyperT #-}
8282

8383
instance Fallible m => Typer (TyperT m r) r where
84-
-- TODO Should be able to use an affine variable in each branch of a `case`,
85-
-- but this abstraction can't capture that. One solution could be to
86-
-- - Add an -- `isolated` operation of type
87-
-- isolated :: m i o () -> m i o <name-usage-map>
88-
-- which doesn't change the state in the monad, but returns the delta that
89-
-- the underlying action tried to add. (Maybe I can even implement this
90-
-- generically if the state is a group?)
91-
-- - Add a `mergeNameMap :: <name-usage-map> -> m i o ()` operation,
92-
-- which would check each key for being used too many times.
93-
-- - Then `case` checks each arm in isolation, zips the maps with maximum,
94-
-- and then calls `mergeNameMap` on the result.
95-
-- I also can't make up my mind whether a `Seq` loop should be allowed to
84+
-- I can't make up my mind whether a `Seq` loop should be allowed to
9685
-- close over a dest from an enclosing scope. Status quo permits this.
9786
affineUsed name = TyperT $ do
9887
affines <- get
@@ -102,6 +91,24 @@ instance Fallible m => Typer (TyperT m r) r where
10291
else
10392
put $ insertNameMap name (n + 1) affines
10493
Nothing -> put $ insertNameMap name 1 affines
94+
parallelAffines_ actions = TyperT $ do
95+
-- This method permits using an affine variable in each branch of a `case`.
96+
-- We check each `case` branch in isolation, detecting affine overuse within
97+
-- the branch; then we check whether the union of the variables used in the
98+
-- branches reuses a variable from outside that it shouldn't.
99+
-- This has the down-side of localizing such an error to the case rather
100+
-- than to the offending in-branch use, but that can be improved later.
101+
affines <- get
102+
isolateds <- forM actions \act -> do
103+
put mempty
104+
runTyperT' act
105+
get
106+
put affines
107+
forM_ (toListNameMap $ unionsWithNameMap max isolateds) \(name, ct) ->
108+
case ct of
109+
0 -> return ()
110+
1 -> runTyperT' $ affineUsed name
111+
_ -> error $ "Unexpected multi-used affine name " ++ show name ++ " from case branches."
105112

106113
-- === typeable things ===
107114

@@ -721,8 +728,8 @@ checkCase :: (Typer m r, IRRep r) => Atom r i -> [Alt r i] -> Type r o -> Effect
721728
checkCase scrut alts resultTy effs = do
722729
scrutTy <- getTypeE scrut
723730
altsBinderTys <- checkCaseAltsBinderTys scrutTy
724-
forMZipped_ alts altsBinderTys \alt bs ->
725-
checkAlt resultTy bs effs alt
731+
parallelAffines_ $ zipWith (\alt bs ->
732+
checkAlt resultTy bs effs alt) alts altsBinderTys
726733

727734
checkCaseAltsBinderTys :: (Fallible1 m, EnvReader m, IRRep r) => Type r n -> m n [Type r n]
728735
checkCaseAltsBinderTys ty = case ty of

src/lib/Name.hs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import qualified Data.HashMap.Strict as HM
2525
import qualified Data.Map.Strict as M
2626
import Data.Bits
2727
import Data.Functor ((<&>))
28-
import Data.Foldable (toList)
28+
import Data.Foldable (toList, foldl')
2929
import Data.Maybe (fromJust, catMaybes)
3030
import Data.Hashable
3131
import Data.Kind (Type)
@@ -3215,6 +3215,11 @@ unionWithNameMap f (UnsafeNameMap raw1) (UnsafeNameMap raw2) =
32153215
UnsafeNameMap $ R.unionWith f raw1 raw2
32163216
{-# INLINE unionWithNameMap #-}
32173217

3218+
unionsWithNameMap :: (Foldable f) => (a -> a -> a) -> f (NameMap c a n) -> NameMap c a n
3219+
unionsWithNameMap func maps =
3220+
foldl' (unionWithNameMap func) mempty maps
3221+
{-# INLINE unionsWithNameMap #-}
3222+
32183223
traverseNameMap :: (Applicative f) => (a -> f b)
32193224
-> NameMap c a n -> f (NameMap c b n)
32203225
traverseNameMap f (UnsafeNameMap raw) = UnsafeNameMap <$> traverse f raw

0 commit comments

Comments
 (0)