@@ -35,7 +35,7 @@ import PPrint
3535import QueryType
3636import Types.Core
3737import Types.Primitives
38- import Util (enumerate )
38+ import Util (enumerate , foldMapM )
3939
4040-- === For loop resolution ===
4141
@@ -109,6 +109,7 @@ lowerExpr :: Emits o => SExpr i -> LowerM i o (SAtom o)
109109lowerExpr expr = emitExpr =<< case expr of
110110 TabCon Nothing ty els -> lowerTabCon Nothing ty els
111111 PrimOp (Hof (For dir ixDict body)) -> lowerFor Nothing dir ixDict body
112+ Case e alts ty _ -> lowerCase Nothing e alts ty
112113 _ -> visitGeneric expr
113114
114115lowerBlock :: Emits o => SBlock i -> LowerM i o (SAtom o )
@@ -151,7 +152,6 @@ lowerTabCon maybeDest tabTy elems = do
151152 dest <- case maybeDest of
152153 Just d -> return d
153154 Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest tabTy'
154- destTy <- getType dest
155155 Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ ord -> do
156156 buildBlock $ unsafeFromOrdinal (sink ixTy') $ Var $ sink ord
157157 -- This is emitting a chain of RememberDest ops to force `dest` to be used
@@ -163,16 +163,36 @@ lowerTabCon maybeDest tabTy elems = do
163163 let go incoming_dest [] = return incoming_dest
164164 go incoming_dest ((ord, e): rest) = do
165165 i <- dropSubst $ extendSubst (bord@> SubstVal (IdxRepVal (fromIntegral ord))) $
166- lowerBlock ufoBlock
167- do_one_elt <- buildUnaryLamExpr " dest" destTy \ local_dest -> do
166+ lowerBlock ufoBlock
167+ carried_dest <- buildRememberDest " dest" incoming_dest \ local_dest -> do
168168 idest <- indexRef (Var local_dest) (sink i)
169169 place (FullDest idest) =<< visitAtom e
170170 return UnitVal
171- carried_dest <- emitExpr $ PrimOp $ DAMOp $ RememberDest incoming_dest do_one_elt
172171 go carried_dest rest
173172 dest' <- go dest (enumerate elems)
174173 return $ PrimOp $ DAMOp $ Freeze dest'
175174
175+ lowerCase :: Emits o
176+ => Maybe (Dest SimpIR o ) -> SAtom i -> [Alt SimpIR i ] -> SType i
177+ -> LowerM i o (SExpr o )
178+ lowerCase maybeDest scrut alts resultTy = do
179+ resultTy' <- substM resultTy
180+ dest <- case maybeDest of
181+ Just d -> return d
182+ Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest resultTy'
183+ scrut' <- visitAtom scrut
184+ dest' <- buildRememberDest " case_dest" dest \ local_dest -> do
185+ alts' <- forM alts \ (Abs (b:> ty) body) -> do
186+ ty' <- substM ty
187+ buildAbs (getNameHint b) ty' \ b' ->
188+ extendSubst (b @> Rename b') $
189+ buildBlock do
190+ lowerBlockWithDest (Var $ sink $ local_dest) body $> UnitVal
191+ eff' <- foldMapM getEffects alts'
192+ void $ emitExpr $ Case (sink scrut') alts' UnitTy eff'
193+ return UnitVal
194+ return $ PrimOp $ DAMOp $ Freeze dest'
195+
176196-- Destination-passing traversals
177197--
178198-- The idea here is to try to reuse the memory already allocated for outputs of surrounding
@@ -199,6 +219,9 @@ data ProjDest o
199219 = FullDest (Dest SimpIR o )
200220 | ProjDest (NE. NonEmpty Projection ) (Dest SimpIR o ) -- dest corresponds to the projection applied to name
201221
222+ instance SinkableE ProjDest where
223+ sinkingProofE = todoSinkableProof
224+
202225lookupDest :: DestAssignment i' o -> SAtomName i' -> Maybe (ProjDest o )
203226lookupDest = flip lookupNameMap
204227
@@ -258,6 +281,13 @@ lowerExprWithDest dest expr = case expr of
258281 PrimOp (Hof (RunState Nothing s body)) -> traverseRWS body \ ref' body' -> do
259282 s' <- visitAtom s
260283 return $ RunState ref' s' body'
284+ Case e alts ty _ -> case dest of
285+ Nothing -> lowerCase Nothing e alts ty
286+ Just (FullDest d) -> lowerCase (Just d) e alts ty
287+ Just d -> do
288+ ans <- lowerCase Nothing e alts ty >>= emitExprToAtom
289+ place d ans
290+ return $ Atom ans
261291 _ -> generic
262292 where
263293 tabDest = dest <&> \ case FullDest d -> d; ProjDest _ _ -> error " unexpected projection"
0 commit comments