@@ -62,15 +62,15 @@ toImpFunction cc lam = do
6262 RefTy _ ansTy -> allocDestUnmanaged =<< substM ansTy
6363 _ -> error " Expected a reference type for body destination"
6464 extendSubst (destb @> SubstVal (destToAtom dest)) do
65- void $ translateBlock Nothing body
65+ void $ translateBlock body
6666 resultAtom <- loadAtom dest
6767 repValToList <$> atomToRepVal resultAtom
6868 _ -> do
6969 (argAtoms, resultDest) <- interpretImpArgsWithCC cc (sink ty) vs
7070 extendSubst (bs @@> (SubstVal <$> argAtoms)) do
7171 (DestBlock destb body) <- return bodyAbs
7272 extendSubst (destb @> SubstVal (destToAtom (sink resultDest))) do
73- void $ translateBlock Nothing body
73+ void $ translateBlock body
7474 return []
7575
7676getNaryLamImpArgTypesWithCC :: EnvReader m
@@ -266,8 +266,12 @@ liftImpM cont = do
266266-- === the actual pass ===
267267
268268translateBlock :: forall i o . Emits o
269- => MaybeDest o -> SBlock i -> SubstImpM i o (SAtom o )
270- translateBlock dest (Block _ decls result) =
269+ => SBlock i -> SubstImpM i o (SAtom o )
270+ translateBlock (Block _ decls result) = translateDeclNest decls $ substM result
271+
272+ translateBlockWithDest :: forall i o . Emits o
273+ => Dest o -> SBlock i -> SubstImpM i o (SAtom o )
274+ translateBlockWithDest dest (Block _ decls result) =
271275 translateDeclNest decls $ translateAtom dest result
272276
273277translateDeclNestSubst
@@ -287,6 +291,12 @@ translateDeclNest decls cont = do
287291 withSubst s' cont
288292{-# INLINE translateDeclNest #-}
289293
294+ translateAtom :: forall i o . Emits o
295+ => Dest o -> SAtom i -> SubstImpM i o (SAtom o )
296+ translateAtom dest x = do
297+ x' <- substM x
298+ storeAtom dest x' >> return x'
299+
290300translateExpr :: forall i o . Emits o => SExpr i -> SubstImpM i o (SAtom o )
291301translateExpr expr = confuseGHC >>= \ _ -> case expr of
292302 Hof hof -> toImpHof hof
@@ -313,7 +323,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
313323 case trySelectBranch e' of
314324 Just (con, arg) -> do
315325 Abs b body <- return $ alts !! con
316- extendSubst (b @> SubstVal arg) $ translateBlock Nothing body
326+ extendSubst (b @> SubstVal arg) $ translateBlock body
317327 Nothing -> do
318328 RepVal sumTy (Branch (tag: xss)) <- atomToRepVal e'
319329 ts <- caseAltsBinderTys sumTy
@@ -323,11 +333,11 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
323333 where
324334 go tag xss = do
325335 tag' <- fromScalarAtom tag
326- dest <- maybeAllocDest Nothing =<< substM ty
336+ dest <- allocDest =<< substM ty
327337 emitSwitch tag' (zip xss alts) $
328338 \ (xs, Abs b body) ->
329339 void $ extendSubst (b @> SubstVal (sink xs)) $
330- translateBlock ( Just $ sink dest) body
340+ translateBlockWithDest ( sink dest) body
331341 loadAtom dest
332342 DAMOp damOp -> case damOp of
333343 Seq d ixDict carry f -> do
@@ -338,12 +348,12 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
338348 emitLoop (getNameHint b) d n \ i -> do
339349 idx <- unsafeFromOrdinalImp (sink ixTy) i
340350 void $ extendSubst (b @> SubstVal (PairVal idx (sink carry'))) $
341- translateBlock Nothing body
351+ translateBlock body
342352 return carry'
343353 RememberDest d f -> do
344354 UnaryLamExpr b body <- return f
345355 d' <- substM d
346- void $ extendSubst (b @> SubstVal d') $ translateBlock Nothing body
356+ void $ extendSubst (b @> SubstVal d') $ translateBlock body
347357 return d'
348358 Place ref val -> do
349359 val' <- substM val
@@ -356,21 +366,13 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
356366 TabCon _ ty rows -> do
357367 resultTy@ (TabPi (TabPiType b _)) <- substM ty
358368 let ixTy = binderAnn b
359- dest <- maybeAllocDest Nothing resultTy
369+ dest <- allocDest resultTy
360370 forM_ (zip [0 .. ] rows) \ (i, row) -> do
361371 row' <- substM row
362372 ithDest <- indexDest dest =<< unsafeFromOrdinalImp ixTy (IIdxRepVal i)
363373 storeAtom ithDest row'
364374 loadAtom dest
365375
366- translateAtom :: forall i o . Emits o
367- => MaybeDest o -> SAtom i -> SubstImpM i o (SAtom o )
368- translateAtom maybeDest x = do
369- x' <- substM x
370- case maybeDest of
371- Nothing -> return x'
372- Just dest -> storeAtom dest x' >> return x'
373-
374376toImpRefOp :: Emits o
375377 => SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o )
376378toImpRefOp refDest' m = do
@@ -387,7 +389,7 @@ toImpRefOp refDest' m = do
387389 MPut x -> storeAtom refDest x >> return UnitVal
388390 MGet -> do
389391 Dest resultTy _ <- return refDest
390- dest <- maybeAllocDest Nothing resultTy
392+ dest <- allocDest resultTy
391393 -- It might be more efficient to implement a specialized copy for dests
392394 -- than to go through a general purpose atom.
393395 storeAtom dest =<< loadAtom refDest
@@ -483,7 +485,7 @@ toImpMiscOp op = case op of
483485 ShowAny _ -> error " Shouldn't have ShowAny in simplified IR"
484486 ShowScalar x -> do
485487 resultTy <- getType $ PrimOp $ MiscOp op
486- Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@ (Leaf tabPtr)]) <- maybeAllocDest Nothing resultTy
488+ Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@ (Leaf tabPtr)]) <- allocDest resultTy
487489 xScalar <- fromScalarAtom x
488490 size <- emitInstr $ IShowScalar tabPtr xScalar
489491 let size' = toScalarAtom size
@@ -529,12 +531,12 @@ toImpFor
529531toImpFor resultTy d ixDict (UnaryLamExpr b body) = do
530532 ixTy <- ixTyFromDict =<< substM ixDict
531533 n <- indexSetSizeImp ixTy
532- dest <- maybeAllocDest Nothing resultTy
534+ dest <- allocDest resultTy
533535 emitLoop (getNameHint b) d n \ i -> do
534536 idx <- unsafeFromOrdinalImp (sink ixTy) i
535537 ithDest <- indexDest (sink dest) idx
536538 void $ extendSubst (b @> SubstVal idx) $
537- translateBlock ( Just ithDest) body
539+ translateBlockWithDest ithDest body
538540 loadAtom dest
539541toImpFor _ _ _ _ = error " expected a lambda as the atom argument"
540542
@@ -545,7 +547,7 @@ toImpHof hof = do
545547 For d ixDict lam -> toImpFor resultTy d ixDict lam
546548 While body -> do
547549 body' <- buildBlockImp do
548- ans <- fromScalarAtom =<< translateBlock Nothing body
550+ ans <- fromScalarAtom =<< translateBlock body
549551 return [ans]
550552 emitStatement $ IWhile body'
551553 return UnitVal
@@ -555,14 +557,14 @@ toImpHof hof = do
555557 rDest <- allocDest =<< getType r'
556558 storeAtom rDest r'
557559 extendSubst (h @> SubstVal (Con HeapVal ) <.> ref @> SubstVal (destToAtom rDest)) $
558- translateBlock Nothing body
560+ translateBlock body
559561 RunWriter d (BaseMonoid e _) f -> do
560562 BinaryLamExpr h ref body <- return f
561563 let PairTy ansTy accTy = resultTy
562564 (aDest, wDest) <- case d of
563- Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
565+ Nothing -> destPairUnpack <$> allocDest resultTy
564566 Just d' -> do
565- aDest <- maybeAllocDest Nothing ansTy
567+ aDest <- allocDest ansTy
566568 wDest <- atomToDest =<< substM d'
567569 return (aDest, wDest)
568570 e' <- substM e
@@ -571,23 +573,23 @@ toImpHof hof = do
571573 liftMonoidEmpty accTy' e''
572574 storeAtom wDest emptyVal
573575 void $ extendSubst (h @> SubstVal (Con HeapVal ) <.> ref @> SubstVal (destToAtom wDest)) $
574- translateBlock ( Just aDest) body
576+ translateBlockWithDest aDest body
575577 PairVal <$> loadAtom aDest <*> loadAtom wDest
576578 RunState d s f -> do
577579 BinaryLamExpr h ref body <- return f
578580 let PairTy ansTy _ = resultTy
579581 (aDest, sDest) <- case d of
580- Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
582+ Nothing -> destPairUnpack <$> allocDest resultTy
581583 Just d' -> do
582- aDest <- maybeAllocDest Nothing ansTy
584+ aDest <- allocDest ansTy
583585 sDest <- atomToDest =<< substM d'
584586 return (aDest, sDest)
585587 storeAtom sDest =<< substM s
586588 void $ extendSubst (h @> SubstVal (Con HeapVal ) <.> ref @> SubstVal (destToAtom sDest)) $
587- translateBlock ( Just aDest) body
589+ translateBlockWithDest aDest body
588590 PairVal <$> loadAtom aDest <*> loadAtom sDest
589- RunIO body-> translateBlock Nothing body
590- RunInit body -> translateBlock Nothing body
591+ RunIO body-> translateBlock body
592+ RunInit body -> translateBlock body
591593 where
592594 liftMonoidEmpty :: Emits n => SType n -> SAtom n -> SBuilderM n (SAtom n )
593595 liftMonoidEmpty accTy x = do
@@ -610,8 +612,6 @@ data Dest (n::S) = Dest
610612 (Tree (IExpr n )) -- underlying scalar values
611613 deriving (Show )
612614
613- type MaybeDest n = Maybe (Dest n )
614-
615615data LeafType n where
616616 LeafType :: TypeCtx SimpIR n l -> BaseType -> LeafType n
617617
@@ -972,10 +972,6 @@ allocDestUnmanaged = allocDestWithAllocContext Unmanaged
972972allocDest :: Emits n => SType n -> SubstImpM i n (Dest n )
973973allocDest = allocDestWithAllocContext Managed
974974
975- maybeAllocDest :: Emits n => Maybe (Dest n ) -> SType n -> SubstImpM i n (Dest n )
976- maybeAllocDest (Just d) _ = return d
977- maybeAllocDest Nothing t = allocDest t
978-
979975storeAtom :: Emits n => Dest n -> SAtom n -> SubstImpM i n ()
980976storeAtom dest x = storeRepVal dest =<< atomToRepVal x
981977
@@ -1199,7 +1195,7 @@ emitCall
11991195 -> ImpFunName n -> [SAtom n ] -> SubstImpM i n (SAtom n )
12001196emitCall (PiType bs _ resultTy) f xs = do
12011197 resultTy' <- applySubst (bs @@> map SubstVal xs) resultTy
1202- dest <- maybeAllocDest Nothing resultTy'
1198+ dest <- allocDest resultTy'
12031199 argsImp <- forM xs \ x -> repValToList <$> atomToRepVal x
12041200 destImp <- repValToList <$> atomToRepVal (destToAtom dest)
12051201 let impArgs = concat argsImp ++ destImp
@@ -1415,7 +1411,7 @@ indexSetSizeImp (IxType _ dict) = do
14151411appSpecializedIxMethod :: Emits n => LamExpr SimpIR n -> [SAtom n ] -> SubstImpM i n (SAtom n )
14161412appSpecializedIxMethod simpLam args = do
14171413 LamExpr bs body <- return simpLam
1418- dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock Nothing body
1414+ dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock body
14191415
14201416-- === Abstracting link-time objects ===
14211417
0 commit comments