Skip to content

Commit c0c2c6a

Browse files
committed
Simplify vestigial destination passing implementation in Imp.hs.
To wit, the only time translateExpr could ever get a non-Nothing dest was when translating the result of a block in translateBlock, which is statically always an Atom. So specialize that case to a new translateAtom, and remove the Maybe Dest argument from translateExpr. All of translateExpr's helpers now also statically get Nothing dests, so remove those Maybe Dest arguments as well.
1 parent 31c7f98 commit c0c2c6a

File tree

1 file changed

+64
-74
lines changed

1 file changed

+64
-74
lines changed

src/lib/Imp.hs

Lines changed: 64 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,16 @@ liftImpM cont = do
267267

268268
translateBlock :: forall i o. Emits o
269269
=> MaybeDest o -> SBlock i -> SubstImpM i o (SAtom o)
270-
translateBlock dest (Block _ decls result) = translateDeclNest decls $ translateExpr dest $ Atom result
270+
translateBlock dest (Block _ decls result) =
271+
translateDeclNest decls $ translateAtom dest result
271272

272273
translateDeclNestSubst
273274
:: Emits o => Subst AtomSubstVal l o
274275
-> Nest SDecl l i' -> SubstImpM i o (Subst AtomSubstVal i' o)
275276
translateDeclNestSubst !s = \case
276277
Empty -> return s
277278
Nest (Let b (DeclBinding _ _ expr)) rest -> do
278-
x <- withSubst s $ translateExpr Nothing expr
279+
x <- withSubst s $ translateExpr expr
279280
translateDeclNestSubst (s <>> (b@>SubstVal x)) rest
280281

281282
translateDeclNest :: Emits o
@@ -286,15 +287,14 @@ translateDeclNest decls cont = do
286287
withSubst s' cont
287288
{-# INLINE translateDeclNest #-}
288289

289-
translateExpr :: forall i o. Emits o
290-
=> MaybeDest o -> SExpr i -> SubstImpM i o (SAtom o)
291-
translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
292-
Hof hof -> toImpHof maybeDest hof
290+
translateExpr :: forall i o. Emits o => SExpr i -> SubstImpM i o (SAtom o)
291+
translateExpr expr = confuseGHC >>= \_ -> case expr of
292+
Hof hof -> toImpHof hof
293293
TopApp f' xs' -> do
294294
f <- substM f'
295295
xs <- mapM substM xs'
296296
lookupTopFun f >>= \case
297-
DexTopFun _ piTy _ _ -> emitCall maybeDest piTy f $ toList xs
297+
DexTopFun _ piTy _ _ -> emitCall piTy f $ toList xs
298298
FFITopFun _ _ -> do
299299
resultTy <- getType $ TopApp f xs
300300
scalarArgs <- liftM toList $ mapM fromScalarAtom xs
@@ -304,16 +304,16 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
304304
xs <- mapM substM xs'
305305
f <- atomToRepVal =<< substM f'
306306
repValAtom =<< naryIndexRepVal f (toList xs)
307-
Atom x -> substM x >>= returnVal
307+
Atom x -> substM x
308308
-- Inlining the traversal helps GHC sink the substM below the case inside toImpOp.
309-
PrimOp op -> (inline traversePrimOp) substM op >>= toImpOp maybeDest
310-
RefOp refDest eff -> toImpRefOp maybeDest refDest eff
309+
PrimOp op -> (inline traversePrimOp) substM op >>= toImpOp
310+
RefOp refDest eff -> toImpRefOp refDest eff
311311
Case e alts ty _ -> do
312312
e' <- substM e
313313
case trySelectBranch e' of
314314
Just (con, arg) -> do
315315
Abs b body <- return $ alts !! con
316-
extendSubst (b @> SubstVal arg) $ translateBlock maybeDest body
316+
extendSubst (b @> SubstVal arg) $ translateBlock Nothing body
317317
Nothing -> do
318318
RepVal sumTy (Branch (tag:xss)) <- atomToRepVal e'
319319
ts <- caseAltsBinderTys sumTy
@@ -323,7 +323,7 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
323323
where
324324
go tag xss = do
325325
tag' <- fromScalarAtom tag
326-
dest <- maybeAllocDest maybeDest =<< substM ty
326+
dest <- maybeAllocDest Nothing =<< substM ty
327327
emitSwitch tag' (zip xss alts) $
328328
\(xs, Abs b body) ->
329329
void $ extendSubst (b @> SubstVal (sink xs)) $
@@ -339,9 +339,7 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
339339
idx <- unsafeFromOrdinalImp (sink ixTy) i
340340
void $ extendSubst (b @> SubstVal (PairVal idx (sink carry'))) $
341341
translateBlock Nothing body
342-
case maybeDest of
343-
Nothing -> return carry'
344-
Just _ -> error "Unexpected dest"
342+
return carry'
345343
RememberDest d f -> do
346344
UnaryLamExpr b body <- return f
347345
d' <- substM d
@@ -350,48 +348,52 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
350348
Place ref val -> do
351349
val' <- substM val
352350
refDest <- atomToDest =<< substM ref
353-
storeAtom refDest val' >> returnVal UnitVal
351+
storeAtom refDest val' >> return UnitVal
354352
Freeze ref -> loadAtom =<< atomToDest =<< substM ref
355353
AllocDest ty -> do
356354
d <- liftM destToAtom $ allocDest =<< substM ty
357-
returnVal d
355+
return d
358356
TabCon _ ty rows -> do
359357
resultTy@(TabPi (TabPiType b _)) <- substM ty
360358
let ixTy = binderAnn b
361-
dest <- maybeAllocDest maybeDest resultTy
359+
dest <- maybeAllocDest Nothing resultTy
362360
forM_ (zip [0..] rows) \(i, row) -> do
363361
row' <- substM row
364362
ithDest <- indexDest dest =<< unsafeFromOrdinalImp ixTy (IIdxRepVal i)
365363
storeAtom ithDest row'
366364
loadAtom dest
367-
where
368-
returnVal atom = case maybeDest of
369-
Nothing -> return atom
370-
Just dest -> storeAtom dest atom >> return atom
365+
366+
translateAtom :: forall i o. Emits o
367+
=> MaybeDest o -> SAtom i -> SubstImpM i o (SAtom o)
368+
translateAtom maybeDest x = do
369+
x' <- substM x
370+
case maybeDest of
371+
Nothing -> return x'
372+
Just dest -> storeAtom dest x' >> return x'
371373

372374
toImpRefOp :: Emits o
373-
=> MaybeDest o -> SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o)
374-
toImpRefOp maybeDest refDest' m = do
375+
=> SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o)
376+
toImpRefOp refDest' m = do
375377
refDest <- atomToDest =<< substM refDest'
376378
substM m >>= \case
377-
MAsk -> returnVal =<< loadAtom refDest
379+
MAsk -> loadAtom refDest
378380
MExtend (BaseMonoid _ combine) x -> do
379381
xTy <- getType x
380382
refVal <- loadAtom refDest
381383
result <- liftBuilderImp $
382384
liftMonoidCombine (sink xTy) (sink combine) (sink refVal) (sink x)
383385
storeAtom refDest result
384-
returnVal UnitVal
385-
MPut x -> storeAtom refDest x >> returnVal UnitVal
386+
return UnitVal
387+
MPut x -> storeAtom refDest x >> return UnitVal
386388
MGet -> do
387389
Dest resultTy _ <- return refDest
388-
dest <- maybeAllocDest maybeDest resultTy
390+
dest <- maybeAllocDest Nothing resultTy
389391
-- It might be more efficient to implement a specialized copy for dests
390392
-- than to go through a general purpose atom.
391393
storeAtom dest =<< loadAtom refDest
392394
loadAtom dest
393-
IndexRef i -> returnVal =<< (destToAtom <$> indexDest refDest i)
394-
ProjRef ~(ProjectProduct i) -> returnVal $ destToAtom $ projectDest i refDest
395+
IndexRef i -> destToAtom <$> indexDest refDest i
396+
ProjRef ~(ProjectProduct i) -> return $ destToAtom $ projectDest i refDest
395397
where
396398
liftMonoidCombine
397399
:: Emits n => SType n -> LamExpr SimpIR n
@@ -413,17 +415,14 @@ toImpRefOp maybeDest refDest' m = do
413415
liftMonoidCombine eltTy' (sink bc) xElt yElt
414416
_ -> error $ "Base monoid type mismatch: can't lift " ++
415417
pprint baseTy ++ " to " ++ pprint accTy
416-
returnVal atom = case maybeDest of
417-
Nothing -> return atom
418-
Just dest -> storeAtom dest atom >> return atom
419418

420419
toImpOp :: forall i o .
421-
Emits o => MaybeDest o -> PrimOp (SAtom o) -> SubstImpM i o (SAtom o)
422-
toImpOp maybeDest op = case op of
420+
Emits o => PrimOp (SAtom o) -> SubstImpM i o (SAtom o)
421+
toImpOp op = case op of
423422
BinOp binOp x y -> returnIExprVal =<< emitInstr =<< (IBinOp binOp <$> fsa x <*> fsa y)
424423
UnOp unOp x -> returnIExprVal =<< emitInstr =<< (IUnOp unOp <$> fsa x)
425-
MemOp memOp -> toImpMemOp maybeDest memOp
426-
MiscOp miscOp -> toImpMiscOp maybeDest miscOp
424+
MemOp memOp -> toImpMemOp memOp
425+
MiscOp miscOp -> toImpMiscOp miscOp
427426
VectorOp (VectorBroadcast val vty) -> do
428427
val' <- fsa val
429428
emitInstr (IVectorBroadcast val' $ toIVectorType vty) >>= returnIExprVal
@@ -433,20 +432,17 @@ toImpOp maybeDest op = case op of
433432
refi <- destToAtom <$> indexDest refDest i
434433
refi' <- fsa refi
435434
let PtrType (addrSpace, _) = getIType refi'
436-
returnVal =<< case vty of
435+
case vty of
437436
BaseTy vty'@(Vector _ _) -> do
438437
resultVal <- cast refi' (PtrType (addrSpace, vty'))
439438
repValAtom $ RepVal (RefTy (Con HeapVal) vty) (Leaf resultVal)
440439
_ -> error "Expected a vector type"
441440
where
442441
fsa = fromScalarAtom
443-
returnIExprVal x = returnVal $ toScalarAtom x
444-
returnVal atom = case maybeDest of
445-
Nothing -> return atom
446-
Just dest -> storeAtom dest atom >> return atom
442+
returnIExprVal x = return $ toScalarAtom x
447443

448-
toImpMiscOp :: Emits o => MaybeDest o -> MiscOp (SAtom o) -> SubstImpM i o (SAtom o)
449-
toImpMiscOp maybeDest op = case op of
444+
toImpMiscOp :: Emits o => MiscOp (SAtom o) -> SubstImpM i o (SAtom o)
445+
toImpMiscOp op = case op of
450446
ThrowError resultTy -> do
451447
emitStatement IThrowError
452448
buildGarbageVal resultTy
@@ -465,19 +461,19 @@ toImpMiscOp maybeDest op = case op of
465461
assertEq srcRep destRep $
466462
"representation types don't match: " ++ pprint srcRep ++ " != " ++ pprint destRep
467463
RepVal _ tree <- atomToRepVal x
468-
returnVal =<< repValAtom (RepVal resultTy tree)
464+
repValAtom (RepVal resultTy tree)
469465
GarbageVal resultTy -> buildGarbageVal resultTy
470466
Select p x y -> do
471467
BaseTy _ <- getType x
472468
returnIExprVal =<< emitInstr =<< (ISelect <$> fsa p <*> fsa x <*> fsa y)
473469
SumTag con -> case con of
474-
Con (SumCon _ tag _) -> returnVal $ TagRepVal $ fromIntegral tag
470+
Con (SumCon _ tag _) -> return $ TagRepVal $ fromIntegral tag
475471
RepValAtom dRepVal -> go dRepVal
476472
_ -> error $ "Not a data constructor: " ++ pprint con
477473
where go dRepVal = do
478474
RepVal _ (Branch (tag:_)) <- return dRepVal
479475
return $ RepValAtom $ RepVal TagRepTy tag
480-
ToEnum ty i -> returnVal =<< case ty of
476+
ToEnum ty i -> case ty of
481477
SumTy cases -> do
482478
i' <- fromScalarAtom i
483479
return $ RepValAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases
@@ -487,7 +483,7 @@ toImpMiscOp maybeDest op = case op of
487483
ShowAny _ -> error "Shouldn't have ShowAny in simplified IR"
488484
ShowScalar x -> do
489485
resultTy <- getType $ PrimOp $ MiscOp op
490-
Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@(Leaf tabPtr)]) <- maybeAllocDest maybeDest resultTy
486+
Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@(Leaf tabPtr)]) <- maybeAllocDest Nothing resultTy
491487
xScalar <- fromScalarAtom x
492488
size <- emitInstr $ IShowScalar tabPtr xScalar
493489
let size' = toScalarAtom size
@@ -496,13 +492,10 @@ toImpMiscOp maybeDest op = case op of
496492
return $ PairVal size' tab
497493
where
498494
fsa = fromScalarAtom
499-
returnIExprVal x = returnVal $ toScalarAtom x
500-
returnVal atom = case maybeDest of
501-
Nothing -> return atom
502-
Just dest -> storeAtom dest atom >> return atom
495+
returnIExprVal x = return $ toScalarAtom x
503496

504-
toImpMemOp :: forall i o . Emits o => MaybeDest o -> MemOp (SAtom o) -> SubstImpM i o (SAtom o)
505-
toImpMemOp maybeDest op = case op of
497+
toImpMemOp :: forall i o . Emits o => MemOp (SAtom o) -> SubstImpM i o (SAtom o)
498+
toImpMemOp op = case op of
506499
IOAlloc ty n -> do
507500
n' <- fsa n
508501
ptr <- emitInstr $ Alloc CPU ty n'
@@ -511,7 +504,7 @@ toImpMemOp maybeDest op = case op of
511504
ptr' <- fsa ptr
512505
emitStatement $ Free ptr'
513506
return UnitVal
514-
PtrOffset arr (IdxRepVal 0) -> returnVal arr
507+
PtrOffset arr (IdxRepVal 0) -> return arr
515508
PtrOffset arr off -> do
516509
arr' <- fsa arr
517510
off' <- fsa off
@@ -527,32 +520,29 @@ toImpMemOp maybeDest op = case op of
527520
return UnitVal
528521
where
529522
fsa = fromScalarAtom
530-
returnIExprVal x = returnVal $ toScalarAtom x
531-
returnVal atom = case maybeDest of
532-
Nothing -> return atom
533-
Just dest -> storeAtom dest atom >> return atom
523+
returnIExprVal x = return $ toScalarAtom x
534524

535525
toImpFor
536-
:: Emits o => Maybe (Dest o) -> SType o -> Direction
526+
:: Emits o => SType o -> Direction
537527
-> IxDict SimpIR i -> LamExpr SimpIR i
538528
-> SubstImpM i o (SAtom o)
539-
toImpFor maybeDest resultTy d ixDict (UnaryLamExpr b body) = do
529+
toImpFor resultTy d ixDict (UnaryLamExpr b body) = do
540530
ixTy <- ixTyFromDict =<< substM ixDict
541531
n <- indexSetSizeImp ixTy
542-
dest <- maybeAllocDest maybeDest resultTy
532+
dest <- maybeAllocDest Nothing resultTy
543533
emitLoop (getNameHint b) d n \i -> do
544534
idx <- unsafeFromOrdinalImp (sink ixTy) i
545535
ithDest <- indexDest (sink dest) idx
546536
void $ extendSubst (b @> SubstVal idx) $
547537
translateBlock (Just ithDest) body
548538
loadAtom dest
549-
toImpFor _ _ _ _ _ = error "expected a lambda as the atom argument"
539+
toImpFor _ _ _ _ = error "expected a lambda as the atom argument"
550540

551-
toImpHof :: Emits o => Maybe (Dest o) -> Hof SimpIR i -> SubstImpM i o (SAtom o)
552-
toImpHof maybeDest hof = do
541+
toImpHof :: Emits o => Hof SimpIR i -> SubstImpM i o (SAtom o)
542+
toImpHof hof = do
553543
resultTy <- getTypeSubst (Hof hof)
554544
case hof of
555-
For d ixDict lam -> toImpFor maybeDest resultTy d ixDict lam
545+
For d ixDict lam -> toImpFor resultTy d ixDict lam
556546
While body -> do
557547
body' <- buildBlockImp do
558548
ans <- fromScalarAtom =<< translateBlock Nothing body
@@ -565,12 +555,12 @@ toImpHof maybeDest hof = do
565555
rDest <- allocDest =<< getType r'
566556
storeAtom rDest r'
567557
extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom rDest)) $
568-
translateBlock maybeDest body
558+
translateBlock Nothing body
569559
RunWriter d (BaseMonoid e _) f -> do
570560
BinaryLamExpr h ref body <- return f
571561
let PairTy ansTy accTy = resultTy
572562
(aDest, wDest) <- case d of
573-
Nothing -> destPairUnpack <$> maybeAllocDest maybeDest resultTy
563+
Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
574564
Just d' -> do
575565
aDest <- maybeAllocDest Nothing ansTy
576566
wDest <- atomToDest =<< substM d'
@@ -587,7 +577,7 @@ toImpHof maybeDest hof = do
587577
BinaryLamExpr h ref body <- return f
588578
let PairTy ansTy _ = resultTy
589579
(aDest, sDest) <- case d of
590-
Nothing -> destPairUnpack <$> maybeAllocDest maybeDest resultTy
580+
Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
591581
Just d' -> do
592582
aDest <- maybeAllocDest Nothing ansTy
593583
sDest <- atomToDest =<< substM d'
@@ -596,8 +586,8 @@ toImpHof maybeDest hof = do
596586
void $ extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom sDest)) $
597587
translateBlock (Just aDest) body
598588
PairVal <$> loadAtom aDest <*> loadAtom sDest
599-
RunIO body-> translateBlock maybeDest body
600-
RunInit body -> translateBlock maybeDest body
589+
RunIO body-> translateBlock Nothing body
590+
RunInit body -> translateBlock Nothing body
601591
where
602592
liftMonoidEmpty :: Emits n => SType n -> SAtom n -> SBuilderM n (SAtom n)
603593
liftMonoidEmpty accTy x = do
@@ -1205,11 +1195,11 @@ withFreshIBinder hint ty cont = do
12051195
{-# INLINE withFreshIBinder #-}
12061196

12071197
emitCall
1208-
:: Emits n => MaybeDest n -> PiType SimpIR n
1198+
:: Emits n => PiType SimpIR n
12091199
-> ImpFunName n -> [SAtom n] -> SubstImpM i n (SAtom n)
1210-
emitCall maybeDest (PiType bs _ resultTy) f xs = do
1200+
emitCall (PiType bs _ resultTy) f xs = do
12111201
resultTy' <- applySubst (bs @@> map SubstVal xs) resultTy
1212-
dest <- maybeAllocDest maybeDest resultTy'
1202+
dest <- maybeAllocDest Nothing resultTy'
12131203
argsImp <- forM xs \x -> repValToList <$> atomToRepVal x
12141204
destImp <- repValToList <$> atomToRepVal (destToAtom dest)
12151205
let impArgs = concat argsImp ++ destImp

0 commit comments

Comments
 (0)