Skip to content

Commit ab43029

Browse files
committed
Implement destination passing for Case expressions.
This way we won't have to worry about SoA-converting the return value of a Case if we should later pull SoA-conversion out of Imp into its own pass.
1 parent e05ed56 commit ab43029

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

src/lib/Builder.hs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,15 @@ emitRunReader hint r body = do
864864
lam <- buildEffLam hint rTy \h ref -> body h ref
865865
emitExpr $ PrimOp $ Hof $ RunReader r lam
866866

867+
buildRememberDest :: (Emits n, ScopableBuilder SimpIR m)
868+
=> NameHint -> SAtom n
869+
-> (forall l. (Emits l, Distinct l, DExt n l) => SAtomName l -> m l (SAtom l))
870+
-> m n (SAtom n)
871+
buildRememberDest hint dest cont = do
872+
ty <- getType dest
873+
doit <- buildUnaryLamExpr hint ty cont
874+
emitExpr $ PrimOp $ DAMOp $ RememberDest dest doit
875+
867876
-- === vector space (ish) type class ===
868877

869878
zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n)

src/lib/Imp.hs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,11 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
304304
repValAtom =<< naryIndexRepVal f (toList xs)
305305
Atom x -> substM x
306306
PrimOp op -> toImpOp op
307-
Case e alts ty _ -> do
307+
Case e alts unitResultTy _ -> do
308308
e' <- substM e
309+
case unitResultTy of
310+
UnitTy -> return ()
311+
_ -> error $ "Unexpected returning Case in Imp " ++ pprint expr
309312
case trySelectBranch e' of
310313
Just (con, arg) -> do
311314
Abs b body <- return $ alts !! con
@@ -319,12 +322,11 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
319322
where
320323
go tag xss = do
321324
tag' <- fromScalarAtom tag
322-
dest <- allocDest =<< substM ty
323325
emitSwitch tag' (zip xss alts) $
324326
\(xs, Abs b body) ->
325327
extendSubst (b @> SubstVal (sink xs)) $
326-
translateBlock body >>= storeAtom (sink dest)
327-
loadAtom dest
328+
void $ translateBlock body
329+
return UnitVal
328330
TabCon _ _ _ -> error "Unexpected `TabCon` in Imp pass."
329331

330332
toImpRefOp :: Emits o

src/lib/Lower.hs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import PPrint
3535
import QueryType
3636
import Types.Core
3737
import Types.Primitives
38-
import Util (enumerate)
38+
import Util (enumerate, foldMapM)
3939

4040
-- === For loop resolution ===
4141

@@ -109,6 +109,7 @@ lowerExpr :: Emits o => SExpr i -> LowerM i o (SAtom o)
109109
lowerExpr expr = emitExpr =<< case expr of
110110
TabCon Nothing ty els -> lowerTabCon Nothing ty els
111111
PrimOp (Hof (For dir ixDict body)) -> lowerFor Nothing dir ixDict body
112+
Case e alts ty _ -> lowerCase Nothing e alts ty
112113
_ -> visitGeneric expr
113114

114115
lowerBlock :: Emits o => SBlock i -> LowerM i o (SAtom o)
@@ -151,7 +152,6 @@ lowerTabCon maybeDest tabTy elems = do
151152
dest <- case maybeDest of
152153
Just d -> return d
153154
Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest tabTy'
154-
destTy <- getType dest
155155
Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do
156156
buildBlock $ unsafeFromOrdinal (sink ixTy') $ Var $ sink ord
157157
-- This is emitting a chain of RememberDest ops to force `dest` to be used
@@ -163,16 +163,36 @@ lowerTabCon maybeDest tabTy elems = do
163163
let go incoming_dest [] = return incoming_dest
164164
go incoming_dest ((ord, e):rest) = do
165165
i <- dropSubst $ extendSubst (bord@>SubstVal (IdxRepVal (fromIntegral ord))) $
166-
lowerBlock ufoBlock
167-
do_one_elt <- buildUnaryLamExpr "dest" destTy \local_dest -> do
166+
lowerBlock ufoBlock
167+
carried_dest <- buildRememberDest "dest" incoming_dest \local_dest -> do
168168
idest <- indexRef (Var local_dest) (sink i)
169169
place (FullDest idest) =<< visitAtom e
170170
return UnitVal
171-
carried_dest <- emitExpr $ PrimOp $ DAMOp $ RememberDest incoming_dest do_one_elt
172171
go carried_dest rest
173172
dest' <- go dest (enumerate elems)
174173
return $ PrimOp $ DAMOp $ Freeze dest'
175174

175+
lowerCase :: Emits o
176+
=> Maybe (Dest SimpIR o) -> SAtom i -> [Alt SimpIR i] -> SType i
177+
-> LowerM i o (SExpr o)
178+
lowerCase maybeDest scrut alts resultTy = do
179+
resultTy' <- substM resultTy
180+
dest <- case maybeDest of
181+
Just d -> return d
182+
Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest resultTy'
183+
scrut' <- visitAtom scrut
184+
dest' <- buildRememberDest "case_dest" dest \local_dest -> do
185+
alts' <- forM alts \(Abs (b:>ty) body) -> do
186+
ty' <- substM ty
187+
buildAbs (getNameHint b) ty' \b' ->
188+
extendSubst (b @> Rename b') $
189+
buildBlock do
190+
lowerBlockWithDest (Var $ sink $ local_dest) body $> UnitVal
191+
eff' <- foldMapM getEffects alts'
192+
void $ emitExpr $ Case (sink scrut') alts' UnitTy eff'
193+
return UnitVal
194+
return $ PrimOp $ DAMOp $ Freeze dest'
195+
176196
-- Destination-passing traversals
177197
--
178198
-- The idea here is to try to reuse the memory already allocated for outputs of surrounding
@@ -199,6 +219,9 @@ data ProjDest o
199219
= FullDest (Dest SimpIR o)
200220
| ProjDest (NE.NonEmpty Projection) (Dest SimpIR o) -- dest corresponds to the projection applied to name
201221

222+
instance SinkableE ProjDest where
223+
sinkingProofE = todoSinkableProof
224+
202225
lookupDest :: DestAssignment i' o -> SAtomName i' -> Maybe (ProjDest o)
203226
lookupDest = flip lookupNameMap
204227

@@ -258,6 +281,13 @@ lowerExprWithDest dest expr = case expr of
258281
PrimOp (Hof (RunState Nothing s body)) -> traverseRWS body \ref' body' -> do
259282
s' <- visitAtom s
260283
return $ RunState ref' s' body'
284+
Case e alts ty _ -> case dest of
285+
Nothing -> lowerCase Nothing e alts ty
286+
Just (FullDest d) -> lowerCase (Just d) e alts ty
287+
Just d -> do
288+
ans <- lowerCase Nothing e alts ty >>= emitExprToAtom
289+
place d ans
290+
return $ Atom ans
261291
_ -> generic
262292
where
263293
tabDest = dest <&> \case FullDest d -> d; ProjDest _ _ -> error "unexpected projection"

0 commit comments

Comments
 (0)