@@ -267,15 +267,16 @@ liftImpM cont = do
267267
268268translateBlock :: forall i o . Emits o
269269 => MaybeDest o -> SBlock i -> SubstImpM i o (SAtom o )
270- translateBlock dest (Block _ decls result) = translateDeclNest decls $ translateExpr dest $ Atom result
270+ translateBlock dest (Block _ decls result) =
271+ translateDeclNest decls $ translateAtom dest result
271272
272273translateDeclNestSubst
273274 :: Emits o => Subst AtomSubstVal l o
274275 -> Nest SDecl l i' -> SubstImpM i o (Subst AtomSubstVal i' o )
275276translateDeclNestSubst ! s = \ case
276277 Empty -> return s
277278 Nest (Let b (DeclBinding _ _ expr)) rest -> do
278- x <- withSubst s $ translateExpr Nothing expr
279+ x <- withSubst s $ translateExpr expr
279280 translateDeclNestSubst (s <>> (b@> SubstVal x)) rest
280281
281282translateDeclNest :: Emits o
@@ -286,15 +287,14 @@ translateDeclNest decls cont = do
286287 withSubst s' cont
287288{-# INLINE translateDeclNest #-}
288289
289- translateExpr :: forall i o . Emits o
290- => MaybeDest o -> SExpr i -> SubstImpM i o (SAtom o )
291- translateExpr maybeDest expr = confuseGHC >>= \ _ -> case expr of
292- Hof hof -> toImpHof maybeDest hof
290+ translateExpr :: forall i o . Emits o => SExpr i -> SubstImpM i o (SAtom o )
291+ translateExpr expr = confuseGHC >>= \ _ -> case expr of
292+ Hof hof -> toImpHof hof
293293 TopApp f' xs' -> do
294294 f <- substM f'
295295 xs <- mapM substM xs'
296296 lookupTopFun f >>= \ case
297- DexTopFun _ piTy _ _ -> emitCall maybeDest piTy f $ toList xs
297+ DexTopFun _ piTy _ _ -> emitCall piTy f $ toList xs
298298 FFITopFun _ _ -> do
299299 resultTy <- getType $ TopApp f xs
300300 scalarArgs <- liftM toList $ mapM fromScalarAtom xs
@@ -304,16 +304,16 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
304304 xs <- mapM substM xs'
305305 f <- atomToRepVal =<< substM f'
306306 repValAtom =<< naryIndexRepVal f (toList xs)
307- Atom x -> substM x >>= returnVal
307+ Atom x -> substM x
308308 -- Inlining the traversal helps GHC sink the substM below the case inside toImpOp.
309- PrimOp op -> (inline traversePrimOp) substM op >>= toImpOp maybeDest
310- RefOp refDest eff -> toImpRefOp maybeDest refDest eff
309+ PrimOp op -> (inline traversePrimOp) substM op >>= toImpOp
310+ RefOp refDest eff -> toImpRefOp refDest eff
311311 Case e alts ty _ -> do
312312 e' <- substM e
313313 case trySelectBranch e' of
314314 Just (con, arg) -> do
315315 Abs b body <- return $ alts !! con
316- extendSubst (b @> SubstVal arg) $ translateBlock maybeDest body
316+ extendSubst (b @> SubstVal arg) $ translateBlock Nothing body
317317 Nothing -> do
318318 RepVal sumTy (Branch (tag: xss)) <- atomToRepVal e'
319319 ts <- caseAltsBinderTys sumTy
@@ -323,7 +323,7 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
323323 where
324324 go tag xss = do
325325 tag' <- fromScalarAtom tag
326- dest <- maybeAllocDest maybeDest =<< substM ty
326+ dest <- maybeAllocDest Nothing =<< substM ty
327327 emitSwitch tag' (zip xss alts) $
328328 \ (xs, Abs b body) ->
329329 void $ extendSubst (b @> SubstVal (sink xs)) $
@@ -339,9 +339,7 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
339339 idx <- unsafeFromOrdinalImp (sink ixTy) i
340340 void $ extendSubst (b @> SubstVal (PairVal idx (sink carry'))) $
341341 translateBlock Nothing body
342- case maybeDest of
343- Nothing -> return carry'
344- Just _ -> error " Unexpected dest"
342+ return carry'
345343 RememberDest d f -> do
346344 UnaryLamExpr b body <- return f
347345 d' <- substM d
@@ -350,48 +348,52 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
350348 Place ref val -> do
351349 val' <- substM val
352350 refDest <- atomToDest =<< substM ref
353- storeAtom refDest val' >> returnVal UnitVal
351+ storeAtom refDest val' >> return UnitVal
354352 Freeze ref -> loadAtom =<< atomToDest =<< substM ref
355353 AllocDest ty -> do
356354 d <- liftM destToAtom $ allocDest =<< substM ty
357- returnVal d
355+ return d
358356 TabCon _ ty rows -> do
359357 resultTy@ (TabPi (TabPiType b _)) <- substM ty
360358 let ixTy = binderAnn b
361- dest <- maybeAllocDest maybeDest resultTy
359+ dest <- maybeAllocDest Nothing resultTy
362360 forM_ (zip [0 .. ] rows) \ (i, row) -> do
363361 row' <- substM row
364362 ithDest <- indexDest dest =<< unsafeFromOrdinalImp ixTy (IIdxRepVal i)
365363 storeAtom ithDest row'
366364 loadAtom dest
367- where
368- returnVal atom = case maybeDest of
369- Nothing -> return atom
370- Just dest -> storeAtom dest atom >> return atom
365+
366+ translateAtom :: forall i o . Emits o
367+ => MaybeDest o -> SAtom i -> SubstImpM i o (SAtom o )
368+ translateAtom maybeDest x = do
369+ x' <- substM x
370+ case maybeDest of
371+ Nothing -> return x'
372+ Just dest -> storeAtom dest x' >> return x'
371373
372374toImpRefOp :: Emits o
373- => MaybeDest o -> SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o )
374- toImpRefOp maybeDest refDest' m = do
375+ => SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o )
376+ toImpRefOp refDest' m = do
375377 refDest <- atomToDest =<< substM refDest'
376378 substM m >>= \ case
377- MAsk -> returnVal =<< loadAtom refDest
379+ MAsk -> loadAtom refDest
378380 MExtend (BaseMonoid _ combine) x -> do
379381 xTy <- getType x
380382 refVal <- loadAtom refDest
381383 result <- liftBuilderImp $
382384 liftMonoidCombine (sink xTy) (sink combine) (sink refVal) (sink x)
383385 storeAtom refDest result
384- returnVal UnitVal
385- MPut x -> storeAtom refDest x >> returnVal UnitVal
386+ return UnitVal
387+ MPut x -> storeAtom refDest x >> return UnitVal
386388 MGet -> do
387389 Dest resultTy _ <- return refDest
388- dest <- maybeAllocDest maybeDest resultTy
390+ dest <- maybeAllocDest Nothing resultTy
389391 -- It might be more efficient to implement a specialized copy for dests
390392 -- than to go through a general purpose atom.
391393 storeAtom dest =<< loadAtom refDest
392394 loadAtom dest
393- IndexRef i -> returnVal =<< ( destToAtom <$> indexDest refDest i)
394- ProjRef ~ (ProjectProduct i) -> returnVal $ destToAtom $ projectDest i refDest
395+ IndexRef i -> destToAtom <$> indexDest refDest i
396+ ProjRef ~ (ProjectProduct i) -> return $ destToAtom $ projectDest i refDest
395397 where
396398 liftMonoidCombine
397399 :: Emits n => SType n -> LamExpr SimpIR n
@@ -413,17 +415,14 @@ toImpRefOp maybeDest refDest' m = do
413415 liftMonoidCombine eltTy' (sink bc) xElt yElt
414416 _ -> error $ " Base monoid type mismatch: can't lift " ++
415417 pprint baseTy ++ " to " ++ pprint accTy
416- returnVal atom = case maybeDest of
417- Nothing -> return atom
418- Just dest -> storeAtom dest atom >> return atom
419418
420419toImpOp :: forall i o .
421- Emits o => MaybeDest o -> PrimOp (SAtom o ) -> SubstImpM i o (SAtom o )
422- toImpOp maybeDest op = case op of
420+ Emits o => PrimOp (SAtom o ) -> SubstImpM i o (SAtom o )
421+ toImpOp op = case op of
423422 BinOp binOp x y -> returnIExprVal =<< emitInstr =<< (IBinOp binOp <$> fsa x <*> fsa y)
424423 UnOp unOp x -> returnIExprVal =<< emitInstr =<< (IUnOp unOp <$> fsa x)
425- MemOp memOp -> toImpMemOp maybeDest memOp
426- MiscOp miscOp -> toImpMiscOp maybeDest miscOp
424+ MemOp memOp -> toImpMemOp memOp
425+ MiscOp miscOp -> toImpMiscOp miscOp
427426 VectorOp (VectorBroadcast val vty) -> do
428427 val' <- fsa val
429428 emitInstr (IVectorBroadcast val' $ toIVectorType vty) >>= returnIExprVal
@@ -433,20 +432,17 @@ toImpOp maybeDest op = case op of
433432 refi <- destToAtom <$> indexDest refDest i
434433 refi' <- fsa refi
435434 let PtrType (addrSpace, _) = getIType refi'
436- returnVal =<< case vty of
435+ case vty of
437436 BaseTy vty'@ (Vector _ _) -> do
438437 resultVal <- cast refi' (PtrType (addrSpace, vty'))
439438 repValAtom $ RepVal (RefTy (Con HeapVal ) vty) (Leaf resultVal)
440439 _ -> error " Expected a vector type"
441440 where
442441 fsa = fromScalarAtom
443- returnIExprVal x = returnVal $ toScalarAtom x
444- returnVal atom = case maybeDest of
445- Nothing -> return atom
446- Just dest -> storeAtom dest atom >> return atom
442+ returnIExprVal x = return $ toScalarAtom x
447443
448- toImpMiscOp :: Emits o => MaybeDest o -> MiscOp (SAtom o ) -> SubstImpM i o (SAtom o )
449- toImpMiscOp maybeDest op = case op of
444+ toImpMiscOp :: Emits o => MiscOp (SAtom o ) -> SubstImpM i o (SAtom o )
445+ toImpMiscOp op = case op of
450446 ThrowError resultTy -> do
451447 emitStatement IThrowError
452448 buildGarbageVal resultTy
@@ -465,19 +461,19 @@ toImpMiscOp maybeDest op = case op of
465461 assertEq srcRep destRep $
466462 " representation types don't match: " ++ pprint srcRep ++ " != " ++ pprint destRep
467463 RepVal _ tree <- atomToRepVal x
468- returnVal =<< repValAtom (RepVal resultTy tree)
464+ repValAtom (RepVal resultTy tree)
469465 GarbageVal resultTy -> buildGarbageVal resultTy
470466 Select p x y -> do
471467 BaseTy _ <- getType x
472468 returnIExprVal =<< emitInstr =<< (ISelect <$> fsa p <*> fsa x <*> fsa y)
473469 SumTag con -> case con of
474- Con (SumCon _ tag _) -> returnVal $ TagRepVal $ fromIntegral tag
470+ Con (SumCon _ tag _) -> return $ TagRepVal $ fromIntegral tag
475471 RepValAtom dRepVal -> go dRepVal
476472 _ -> error $ " Not a data constructor: " ++ pprint con
477473 where go dRepVal = do
478474 RepVal _ (Branch (tag: _)) <- return dRepVal
479475 return $ RepValAtom $ RepVal TagRepTy tag
480- ToEnum ty i -> returnVal =<< case ty of
476+ ToEnum ty i -> case ty of
481477 SumTy cases -> do
482478 i' <- fromScalarAtom i
483479 return $ RepValAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [] )) cases
@@ -487,7 +483,7 @@ toImpMiscOp maybeDest op = case op of
487483 ShowAny _ -> error " Shouldn't have ShowAny in simplified IR"
488484 ShowScalar x -> do
489485 resultTy <- getType $ PrimOp $ MiscOp op
490- Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@ (Leaf tabPtr)]) <- maybeAllocDest maybeDest resultTy
486+ Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@ (Leaf tabPtr)]) <- maybeAllocDest Nothing resultTy
491487 xScalar <- fromScalarAtom x
492488 size <- emitInstr $ IShowScalar tabPtr xScalar
493489 let size' = toScalarAtom size
@@ -496,13 +492,10 @@ toImpMiscOp maybeDest op = case op of
496492 return $ PairVal size' tab
497493 where
498494 fsa = fromScalarAtom
499- returnIExprVal x = returnVal $ toScalarAtom x
500- returnVal atom = case maybeDest of
501- Nothing -> return atom
502- Just dest -> storeAtom dest atom >> return atom
495+ returnIExprVal x = return $ toScalarAtom x
503496
504- toImpMemOp :: forall i o . Emits o => MaybeDest o -> MemOp (SAtom o ) -> SubstImpM i o (SAtom o )
505- toImpMemOp maybeDest op = case op of
497+ toImpMemOp :: forall i o . Emits o => MemOp (SAtom o ) -> SubstImpM i o (SAtom o )
498+ toImpMemOp op = case op of
506499 IOAlloc ty n -> do
507500 n' <- fsa n
508501 ptr <- emitInstr $ Alloc CPU ty n'
@@ -511,7 +504,7 @@ toImpMemOp maybeDest op = case op of
511504 ptr' <- fsa ptr
512505 emitStatement $ Free ptr'
513506 return UnitVal
514- PtrOffset arr (IdxRepVal 0 ) -> returnVal arr
507+ PtrOffset arr (IdxRepVal 0 ) -> return arr
515508 PtrOffset arr off -> do
516509 arr' <- fsa arr
517510 off' <- fsa off
@@ -527,32 +520,29 @@ toImpMemOp maybeDest op = case op of
527520 return UnitVal
528521 where
529522 fsa = fromScalarAtom
530- returnIExprVal x = returnVal $ toScalarAtom x
531- returnVal atom = case maybeDest of
532- Nothing -> return atom
533- Just dest -> storeAtom dest atom >> return atom
523+ returnIExprVal x = return $ toScalarAtom x
534524
535525toImpFor
536- :: Emits o => Maybe ( Dest o ) -> SType o -> Direction
526+ :: Emits o => SType o -> Direction
537527 -> IxDict SimpIR i -> LamExpr SimpIR i
538528 -> SubstImpM i o (SAtom o )
539- toImpFor maybeDest resultTy d ixDict (UnaryLamExpr b body) = do
529+ toImpFor resultTy d ixDict (UnaryLamExpr b body) = do
540530 ixTy <- ixTyFromDict =<< substM ixDict
541531 n <- indexSetSizeImp ixTy
542- dest <- maybeAllocDest maybeDest resultTy
532+ dest <- maybeAllocDest Nothing resultTy
543533 emitLoop (getNameHint b) d n \ i -> do
544534 idx <- unsafeFromOrdinalImp (sink ixTy) i
545535 ithDest <- indexDest (sink dest) idx
546536 void $ extendSubst (b @> SubstVal idx) $
547537 translateBlock (Just ithDest) body
548538 loadAtom dest
549- toImpFor _ _ _ _ _ = error " expected a lambda as the atom argument"
539+ toImpFor _ _ _ _ = error " expected a lambda as the atom argument"
550540
551- toImpHof :: Emits o => Maybe ( Dest o ) -> Hof SimpIR i -> SubstImpM i o (SAtom o )
552- toImpHof maybeDest hof = do
541+ toImpHof :: Emits o => Hof SimpIR i -> SubstImpM i o (SAtom o )
542+ toImpHof hof = do
553543 resultTy <- getTypeSubst (Hof hof)
554544 case hof of
555- For d ixDict lam -> toImpFor maybeDest resultTy d ixDict lam
545+ For d ixDict lam -> toImpFor resultTy d ixDict lam
556546 While body -> do
557547 body' <- buildBlockImp do
558548 ans <- fromScalarAtom =<< translateBlock Nothing body
@@ -565,12 +555,12 @@ toImpHof maybeDest hof = do
565555 rDest <- allocDest =<< getType r'
566556 storeAtom rDest r'
567557 extendSubst (h @> SubstVal (Con HeapVal ) <.> ref @> SubstVal (destToAtom rDest)) $
568- translateBlock maybeDest body
558+ translateBlock Nothing body
569559 RunWriter d (BaseMonoid e _) f -> do
570560 BinaryLamExpr h ref body <- return f
571561 let PairTy ansTy accTy = resultTy
572562 (aDest, wDest) <- case d of
573- Nothing -> destPairUnpack <$> maybeAllocDest maybeDest resultTy
563+ Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
574564 Just d' -> do
575565 aDest <- maybeAllocDest Nothing ansTy
576566 wDest <- atomToDest =<< substM d'
@@ -587,7 +577,7 @@ toImpHof maybeDest hof = do
587577 BinaryLamExpr h ref body <- return f
588578 let PairTy ansTy _ = resultTy
589579 (aDest, sDest) <- case d of
590- Nothing -> destPairUnpack <$> maybeAllocDest maybeDest resultTy
580+ Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
591581 Just d' -> do
592582 aDest <- maybeAllocDest Nothing ansTy
593583 sDest <- atomToDest =<< substM d'
@@ -596,8 +586,8 @@ toImpHof maybeDest hof = do
596586 void $ extendSubst (h @> SubstVal (Con HeapVal ) <.> ref @> SubstVal (destToAtom sDest)) $
597587 translateBlock (Just aDest) body
598588 PairVal <$> loadAtom aDest <*> loadAtom sDest
599- RunIO body-> translateBlock maybeDest body
600- RunInit body -> translateBlock maybeDest body
589+ RunIO body-> translateBlock Nothing body
590+ RunInit body -> translateBlock Nothing body
601591 where
602592 liftMonoidEmpty :: Emits n => SType n -> SAtom n -> SBuilderM n (SAtom n )
603593 liftMonoidEmpty accTy x = do
@@ -1205,11 +1195,11 @@ withFreshIBinder hint ty cont = do
12051195{-# INLINE withFreshIBinder #-}
12061196
12071197emitCall
1208- :: Emits n => MaybeDest n -> PiType SimpIR n
1198+ :: Emits n => PiType SimpIR n
12091199 -> ImpFunName n -> [SAtom n ] -> SubstImpM i n (SAtom n )
1210- emitCall maybeDest (PiType bs _ resultTy) f xs = do
1200+ emitCall (PiType bs _ resultTy) f xs = do
12111201 resultTy' <- applySubst (bs @@> map SubstVal xs) resultTy
1212- dest <- maybeAllocDest maybeDest resultTy'
1202+ dest <- maybeAllocDest Nothing resultTy'
12131203 argsImp <- forM xs \ x -> repValToList <$> atomToRepVal x
12141204 destImp <- repValToList <$> atomToRepVal (destToAtom dest)
12151205 let impArgs = concat argsImp ++ destImp
0 commit comments