Skip to content

Commit 207b0ea

Browse files
committed
Now every use site of translateBlock statically knows whether it passes Nothing or Just dest.
Split translateBlock into two functions specialized on those cases.
1 parent c0c2c6a commit 207b0ea

File tree

1 file changed

+36
-40
lines changed

1 file changed

+36
-40
lines changed

src/lib/Imp.hs

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ toImpFunction cc lam = do
6262
RefTy _ ansTy -> allocDestUnmanaged =<< substM ansTy
6363
_ -> error "Expected a reference type for body destination"
6464
extendSubst (destb @> SubstVal (destToAtom dest)) do
65-
void $ translateBlock Nothing body
65+
void $ translateBlock body
6666
resultAtom <- loadAtom dest
6767
repValToList <$> atomToRepVal resultAtom
6868
_ -> do
6969
(argAtoms, resultDest) <- interpretImpArgsWithCC cc (sink ty) vs
7070
extendSubst (bs @@> (SubstVal <$> argAtoms)) do
7171
(DestBlock destb body) <- return bodyAbs
7272
extendSubst (destb @> SubstVal (destToAtom (sink resultDest))) do
73-
void $ translateBlock Nothing body
73+
void $ translateBlock body
7474
return []
7575

7676
getNaryLamImpArgTypesWithCC :: EnvReader m
@@ -266,8 +266,12 @@ liftImpM cont = do
266266
-- === the actual pass ===
267267

268268
translateBlock :: forall i o. Emits o
269-
=> MaybeDest o -> SBlock i -> SubstImpM i o (SAtom o)
270-
translateBlock dest (Block _ decls result) =
269+
=> SBlock i -> SubstImpM i o (SAtom o)
270+
translateBlock (Block _ decls result) = translateDeclNest decls $ substM result
271+
272+
translateBlockWithDest :: forall i o. Emits o
273+
=> Dest o -> SBlock i -> SubstImpM i o (SAtom o)
274+
translateBlockWithDest dest (Block _ decls result) =
271275
translateDeclNest decls $ translateAtom dest result
272276

273277
translateDeclNestSubst
@@ -287,6 +291,12 @@ translateDeclNest decls cont = do
287291
withSubst s' cont
288292
{-# INLINE translateDeclNest #-}
289293

294+
translateAtom :: forall i o. Emits o
295+
=> Dest o -> SAtom i -> SubstImpM i o (SAtom o)
296+
translateAtom dest x = do
297+
x' <- substM x
298+
storeAtom dest x' >> return x'
299+
290300
translateExpr :: forall i o. Emits o => SExpr i -> SubstImpM i o (SAtom o)
291301
translateExpr expr = confuseGHC >>= \_ -> case expr of
292302
Hof hof -> toImpHof hof
@@ -313,7 +323,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
313323
case trySelectBranch e' of
314324
Just (con, arg) -> do
315325
Abs b body <- return $ alts !! con
316-
extendSubst (b @> SubstVal arg) $ translateBlock Nothing body
326+
extendSubst (b @> SubstVal arg) $ translateBlock body
317327
Nothing -> do
318328
RepVal sumTy (Branch (tag:xss)) <- atomToRepVal e'
319329
ts <- caseAltsBinderTys sumTy
@@ -323,11 +333,11 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
323333
where
324334
go tag xss = do
325335
tag' <- fromScalarAtom tag
326-
dest <- maybeAllocDest Nothing =<< substM ty
336+
dest <- allocDest =<< substM ty
327337
emitSwitch tag' (zip xss alts) $
328338
\(xs, Abs b body) ->
329339
void $ extendSubst (b @> SubstVal (sink xs)) $
330-
translateBlock (Just $ sink dest) body
340+
translateBlockWithDest (sink dest) body
331341
loadAtom dest
332342
DAMOp damOp -> case damOp of
333343
Seq d ixDict carry f -> do
@@ -338,12 +348,12 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
338348
emitLoop (getNameHint b) d n \i -> do
339349
idx <- unsafeFromOrdinalImp (sink ixTy) i
340350
void $ extendSubst (b @> SubstVal (PairVal idx (sink carry'))) $
341-
translateBlock Nothing body
351+
translateBlock body
342352
return carry'
343353
RememberDest d f -> do
344354
UnaryLamExpr b body <- return f
345355
d' <- substM d
346-
void $ extendSubst (b @> SubstVal d') $ translateBlock Nothing body
356+
void $ extendSubst (b @> SubstVal d') $ translateBlock body
347357
return d'
348358
Place ref val -> do
349359
val' <- substM val
@@ -356,21 +366,13 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
356366
TabCon _ ty rows -> do
357367
resultTy@(TabPi (TabPiType b _)) <- substM ty
358368
let ixTy = binderAnn b
359-
dest <- maybeAllocDest Nothing resultTy
369+
dest <- allocDest resultTy
360370
forM_ (zip [0..] rows) \(i, row) -> do
361371
row' <- substM row
362372
ithDest <- indexDest dest =<< unsafeFromOrdinalImp ixTy (IIdxRepVal i)
363373
storeAtom ithDest row'
364374
loadAtom dest
365375

366-
translateAtom :: forall i o. Emits o
367-
=> MaybeDest o -> SAtom i -> SubstImpM i o (SAtom o)
368-
translateAtom maybeDest x = do
369-
x' <- substM x
370-
case maybeDest of
371-
Nothing -> return x'
372-
Just dest -> storeAtom dest x' >> return x'
373-
374376
toImpRefOp :: Emits o
375377
=> SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o)
376378
toImpRefOp refDest' m = do
@@ -387,7 +389,7 @@ toImpRefOp refDest' m = do
387389
MPut x -> storeAtom refDest x >> return UnitVal
388390
MGet -> do
389391
Dest resultTy _ <- return refDest
390-
dest <- maybeAllocDest Nothing resultTy
392+
dest <- allocDest resultTy
391393
-- It might be more efficient to implement a specialized copy for dests
392394
-- than to go through a general purpose atom.
393395
storeAtom dest =<< loadAtom refDest
@@ -483,7 +485,7 @@ toImpMiscOp op = case op of
483485
ShowAny _ -> error "Shouldn't have ShowAny in simplified IR"
484486
ShowScalar x -> do
485487
resultTy <- getType $ PrimOp $ MiscOp op
486-
Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@(Leaf tabPtr)]) <- maybeAllocDest Nothing resultTy
488+
Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@(Leaf tabPtr)]) <- allocDest resultTy
487489
xScalar <- fromScalarAtom x
488490
size <- emitInstr $ IShowScalar tabPtr xScalar
489491
let size' = toScalarAtom size
@@ -529,12 +531,12 @@ toImpFor
529531
toImpFor resultTy d ixDict (UnaryLamExpr b body) = do
530532
ixTy <- ixTyFromDict =<< substM ixDict
531533
n <- indexSetSizeImp ixTy
532-
dest <- maybeAllocDest Nothing resultTy
534+
dest <- allocDest resultTy
533535
emitLoop (getNameHint b) d n \i -> do
534536
idx <- unsafeFromOrdinalImp (sink ixTy) i
535537
ithDest <- indexDest (sink dest) idx
536538
void $ extendSubst (b @> SubstVal idx) $
537-
translateBlock (Just ithDest) body
539+
translateBlockWithDest ithDest body
538540
loadAtom dest
539541
toImpFor _ _ _ _ = error "expected a lambda as the atom argument"
540542

@@ -545,7 +547,7 @@ toImpHof hof = do
545547
For d ixDict lam -> toImpFor resultTy d ixDict lam
546548
While body -> do
547549
body' <- buildBlockImp do
548-
ans <- fromScalarAtom =<< translateBlock Nothing body
550+
ans <- fromScalarAtom =<< translateBlock body
549551
return [ans]
550552
emitStatement $ IWhile body'
551553
return UnitVal
@@ -555,14 +557,14 @@ toImpHof hof = do
555557
rDest <- allocDest =<< getType r'
556558
storeAtom rDest r'
557559
extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom rDest)) $
558-
translateBlock Nothing body
560+
translateBlock body
559561
RunWriter d (BaseMonoid e _) f -> do
560562
BinaryLamExpr h ref body <- return f
561563
let PairTy ansTy accTy = resultTy
562564
(aDest, wDest) <- case d of
563-
Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
565+
Nothing -> destPairUnpack <$> allocDest resultTy
564566
Just d' -> do
565-
aDest <- maybeAllocDest Nothing ansTy
567+
aDest <- allocDest ansTy
566568
wDest <- atomToDest =<< substM d'
567569
return (aDest, wDest)
568570
e' <- substM e
@@ -571,23 +573,23 @@ toImpHof hof = do
571573
liftMonoidEmpty accTy' e''
572574
storeAtom wDest emptyVal
573575
void $ extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom wDest)) $
574-
translateBlock (Just aDest) body
576+
translateBlockWithDest aDest body
575577
PairVal <$> loadAtom aDest <*> loadAtom wDest
576578
RunState d s f -> do
577579
BinaryLamExpr h ref body <- return f
578580
let PairTy ansTy _ = resultTy
579581
(aDest, sDest) <- case d of
580-
Nothing -> destPairUnpack <$> maybeAllocDest Nothing resultTy
582+
Nothing -> destPairUnpack <$> allocDest resultTy
581583
Just d' -> do
582-
aDest <- maybeAllocDest Nothing ansTy
584+
aDest <- allocDest ansTy
583585
sDest <- atomToDest =<< substM d'
584586
return (aDest, sDest)
585587
storeAtom sDest =<< substM s
586588
void $ extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom sDest)) $
587-
translateBlock (Just aDest) body
589+
translateBlockWithDest aDest body
588590
PairVal <$> loadAtom aDest <*> loadAtom sDest
589-
RunIO body-> translateBlock Nothing body
590-
RunInit body -> translateBlock Nothing body
591+
RunIO body-> translateBlock body
592+
RunInit body -> translateBlock body
591593
where
592594
liftMonoidEmpty :: Emits n => SType n -> SAtom n -> SBuilderM n (SAtom n)
593595
liftMonoidEmpty accTy x = do
@@ -610,8 +612,6 @@ data Dest (n::S) = Dest
610612
(Tree (IExpr n)) -- underlying scalar values
611613
deriving (Show)
612614

613-
type MaybeDest n = Maybe (Dest n)
614-
615615
data LeafType n where
616616
LeafType :: TypeCtx SimpIR n l -> BaseType -> LeafType n
617617

@@ -972,10 +972,6 @@ allocDestUnmanaged = allocDestWithAllocContext Unmanaged
972972
allocDest :: Emits n => SType n -> SubstImpM i n (Dest n)
973973
allocDest = allocDestWithAllocContext Managed
974974

975-
maybeAllocDest :: Emits n => Maybe (Dest n) -> SType n -> SubstImpM i n (Dest n)
976-
maybeAllocDest (Just d) _ = return d
977-
maybeAllocDest Nothing t = allocDest t
978-
979975
storeAtom :: Emits n => Dest n -> SAtom n -> SubstImpM i n ()
980976
storeAtom dest x = storeRepVal dest =<< atomToRepVal x
981977

@@ -1199,7 +1195,7 @@ emitCall
11991195
-> ImpFunName n -> [SAtom n] -> SubstImpM i n (SAtom n)
12001196
emitCall (PiType bs _ resultTy) f xs = do
12011197
resultTy' <- applySubst (bs @@> map SubstVal xs) resultTy
1202-
dest <- maybeAllocDest Nothing resultTy'
1198+
dest <- allocDest resultTy'
12031199
argsImp <- forM xs \x -> repValToList <$> atomToRepVal x
12041200
destImp <- repValToList <$> atomToRepVal (destToAtom dest)
12051201
let impArgs = concat argsImp ++ destImp
@@ -1415,7 +1411,7 @@ indexSetSizeImp (IxType _ dict) = do
14151411
appSpecializedIxMethod :: Emits n => LamExpr SimpIR n -> [SAtom n] -> SubstImpM i n (SAtom n)
14161412
appSpecializedIxMethod simpLam args = do
14171413
LamExpr bs body <- return simpLam
1418-
dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock Nothing body
1414+
dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock body
14191415

14201416
-- === Abstracting link-time objects ===
14211417

0 commit comments

Comments
 (0)