Skip to content

Commit 89343a2

Browse files
authored
Merge pull request #1294 from axch/case-of-case
Case of case and other optimizations in the inliner
2 parents ab43029 + df126ba commit 89343a2

File tree

7 files changed

+119
-28
lines changed

7 files changed

+119
-28
lines changed

src/lib/Builder.hs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -768,15 +768,15 @@ injectAltResult sumTys con (Abs b body) = liftBuilder do
768768

769769
-- TODO: consider a version with nonempty list of alternatives where we figure
770770
-- out the result type from one of the alts rather than providing it explicitly
771-
buildCase :: (Emits n, ScopableBuilder r m)
772-
=> Atom r n -> Type r n
773-
-> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l))
774-
-> m n (Atom r n)
775-
buildCase scrut resultTy indexedAltBody = do
771+
buildCase' :: (Emits n, ScopableBuilder r m)
772+
=> Atom r n -> Type r n
773+
-> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l))
774+
-> m n (Expr r n)
775+
buildCase' scrut resultTy indexedAltBody = do
776776
case trySelectBranch scrut of
777777
Just (i, arg) -> do
778778
Distinct <- getDistinct
779-
indexedAltBody i $ sink arg
779+
Atom <$> indexedAltBody i (sink arg)
780780
Nothing -> do
781781
scrutTy <- getType scrut
782782
altBinderTys <- caseAltsBinderTys scrutTy
@@ -786,7 +786,13 @@ buildCase scrut resultTy indexedAltBody = do
786786
eff <- getEffects blk
787787
return $ blk `PairE` eff
788788
return (Abs b' body, ignoreHoistFailure $ hoist b' eff')
789-
emitExpr $ Case scrut alts resultTy $ mconcat effs
789+
return $ Case scrut alts resultTy $ mconcat effs
790+
791+
buildCase :: (Emits n, ScopableBuilder r m)
792+
=> Atom r n -> Type r n
793+
-> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l))
794+
-> m n (Atom r n)
795+
buildCase s r b = emitExprToAtom =<< buildCase' s r b
790796

791797
buildEffLam
792798
:: ScopableBuilder r m

src/lib/CheapReduction.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ module CheapReduction
1515
, unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType
1616
, liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..)
1717
, visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2
18-
, visitBinders, visitPiDefault)
18+
, visitBinders, visitPiDefault, visitAlt)
1919
where
2020

2121
import Control.Applicative

src/lib/Inline.hs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import IRVariants
1616
import Name
1717
import Subst
1818
import Occurrence hiding (Var)
19+
import Optimize
1920
import Types.Core
2021
import Types.Primitives
2122

@@ -90,7 +91,7 @@ inlineDeclsSubst = \case
9091
s <- getSubst
9192
extendSubst (b @> SubstVal (SuspEx expr s)) $ inlineDeclsSubst rest
9293
else do
93-
expr' <- inlineExpr Stop expr
94+
expr' <- inlineExpr Stop expr >>= (liftEnvReaderM . peepholeExpr)
9495
-- If the inliner starts moving effectful expressions, it may become
9596
-- necessary to query the effects of the new expression here.
9697
let presInfo = resolveWorkConservation ann expr'
@@ -248,6 +249,9 @@ data Context (from::E) (to::E) (o::S) where
248249
Stop :: Context e e o
249250
TabAppCtx :: [SAtom i] -> Subst InlineSubstVal i o
250251
-> Context SExpr e o -> Context SExpr e o
252+
CaseCtx :: [SAlt i] -> SType i -> EffectRow SimpIR i
253+
-> Subst InlineSubstVal i o
254+
-> Context SExpr e o -> Context SExpr e o
251255
EmitToAtomCtx :: Context SAtom e o -> Context SExpr e o
252256
EmitToNameCtx :: Context SAtomName e o -> Context SAtom e o
253257

@@ -271,6 +275,9 @@ inlineExpr ctx = \case
271275
TabApp tbl ixs -> do
272276
s <- getSubst
273277
inlineAtom (TabAppCtx ixs s ctx) tbl
278+
Case scrut alts resultTy effs -> do
279+
s <- getSubst
280+
inlineAtom (CaseCtx alts resultTy effs s ctx) scrut
274281
expr -> visitGeneric expr >>= reconstruct ctx
275282

276283
inlineAtom :: Emits o => Context SExpr e o -> SAtom i -> InlineM i o (e o)
@@ -340,12 +347,18 @@ instance Inlinable SBlock where
340347
effs' <- inline Stop effs -- TODO Really?
341348
reconstruct ctx $ Block (BlockAnn ty' effs') decls' ans'
342349

350+
inlineBlockEmits :: Emits o => Context SExpr e2 o -> SBlock i -> InlineM i o (e2 o)
351+
inlineBlockEmits ctx (Block _ decls ans) = do
352+
inlineDecls decls $ inlineAtom ctx ans
353+
343354
-- Still using InlineM because we may call back into inlining, and we wish to
344355
-- retain our output binding environment.
345356
reconstruct :: Emits o => Context e1 e2 o -> e1 o -> InlineM i o (e2 o)
346357
reconstruct ctx e = case ctx of
347358
Stop -> return e
348359
TabAppCtx ixs s ctx' -> withSubst s $ reconstructTabApp ctx' e ixs
360+
CaseCtx alts resultTy effs s ctx' ->
361+
withSubst s $ reconstructCase ctx' e alts resultTy effs
349362
EmitToAtomCtx ctx' -> emitExprToAtom e >>= reconstruct ctx'
350363
EmitToNameCtx ctx' -> emit (Atom e) >>= reconstruct ctx'
351364
{-# INLINE reconstruct #-}
@@ -404,5 +417,38 @@ reconstructTabApp ctx expr ixs =
404417
ixs' <- mapM (inline Stop) ixs
405418
reconstruct ctx $ TabApp array' ixs'
406419

420+
reconstructCase :: Emits o
421+
=> Context SExpr e o -> SExpr o -> [SAlt i] -> SType i -> EffectRow SimpIR i
422+
-> InlineM i o (e o)
423+
reconstructCase ctx scrutExpr alts resultTy effs =
424+
case scrutExpr of
425+
Case sscrut salts _ _ -> do
426+
-- Perform case-of-case optimization
427+
-- TODO Add join points to reduce code duplication (and repeated inlining)
428+
-- of the arms of the outer case
429+
resultTy' <- inline Stop resultTy
430+
reconstruct ctx =<< (buildCase' sscrut resultTy' \i val -> do
431+
ans <- applyAbs (sink $ salts !! i) (SubstVal val) >>= emitBlock
432+
buildCase ans (sink resultTy') \j jval -> do
433+
Abs b body <- return $ alts !! j
434+
extendSubst (b @> (SubstVal $ DoneEx $ Atom jval)) do
435+
inlineBlockEmits Stop body >>= emitExprToAtom)
436+
_ -> do
437+
-- Attempt case-of-known-constructor optimization
438+
-- I can't use `buildCase` here because I want to propagate the incoming
439+
-- context `ctx` into the selected alternative if the optimization fires,
440+
-- but leave it around the whole reconstructed `Case` if it doesn't.
441+
scrut <- emitExprToAtom scrutExpr
442+
case trySelectBranch scrut of
443+
Just (i, val) -> do
444+
Abs b body <- return $ alts !! i
445+
extendSubst (b @> (SubstVal $ DoneEx $ Atom val)) do
446+
inlineBlockEmits ctx body
447+
Nothing -> do
448+
alts' <- mapM visitAlt alts
449+
resultTy' <- inline Stop resultTy
450+
effs' <- inline Stop effs
451+
reconstruct ctx $ Case scrut alts' resultTy' effs'
452+
407453
instance Inlinable (EffectRow SimpIR)
408454
instance Inlinable (EffectAndType SimpIR)

src/lib/Optimize.hs

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
{-# LANGUAGE UndecidableInstances #-}
88

99
module Optimize
10-
( optimize, peepholeOp
10+
( optimize, peepholeOp, peepholeExpr
1111
, hoistLoopInvariant, hoistLoopInvariantDest
1212
, dceTop, dceTopDest
1313
, foldCast ) where
@@ -43,7 +43,7 @@ optimize = dceTop -- Clean up user code
4343

4444
-- === Peephole optimizations ===
4545

46-
peepholeOp :: PrimOp SimpIR o -> EnvReaderM o (Either (SAtom o) (PrimOp SimpIR o))
46+
peepholeOp :: PrimOp SimpIR o -> EnvReaderM o (SExpr o)
4747
peepholeOp op = case op of
4848
MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case foldCast sTy l of
4949
Just l' -> lit l'
@@ -72,14 +72,14 @@ peepholeOp op = case op of
7272
return $ lit $ Word8Lit $ lv .|. rv
7373
BinOp BAnd (Con (Lit (Word8Lit lv))) (Con (Lit (Word8Lit rv))) ->
7474
return $ lit $ Word8Lit $ lv .&. rv
75-
MiscOp (ToEnum ty (Con (Lit (Word8Lit tag)))) -> Left <$> case ty of
76-
SumTy cases -> return $ SumVal cases (fromIntegral tag) UnitVal
75+
MiscOp (ToEnum ty (Con (Lit (Word8Lit tag)))) -> case ty of
76+
SumTy cases -> return $ Atom $ SumVal cases (fromIntegral tag) UnitVal
7777
_ -> error "Ill typed ToEnum?"
78-
MiscOp (SumTag (SumVal _ tag _)) -> return $ lit $ Word8Lit $ fromIntegral tag
78+
MiscOp (SumTag (SumVal _ tag _)) -> return $ lit $ Word8Lit $ fromIntegral tag
7979
_ -> return noop
8080
where
81-
noop = Right op
82-
lit = Left . Con . Lit
81+
noop = PrimOp op
82+
lit = Atom . Con . Lit
8383

8484
cmp :: Ord a => CmpOp -> a -> a -> Bool
8585
cmp = \case
@@ -188,9 +188,9 @@ foldCast sTy l = case sTy of
188188
compare (0 - countTrailingZeros (round @b @a a))
189189
(0 - countTrailingZeros (round @b @a b))
190190

191-
peepholeExpr :: SExpr o -> EnvReaderM o (Either (SAtom o) (SExpr o))
191+
peepholeExpr :: SExpr o -> EnvReaderM o (SExpr o)
192192
peepholeExpr expr = case expr of
193-
PrimOp op -> fmap PrimOp <$> peepholeOp op
193+
PrimOp op -> peepholeOp op
194194
TabApp (Var t) [IdxRepVal ord] ->
195195
lookupAtomName t <&> \case
196196
LetBound (DeclBinding ann _ (TabCon Nothing tabTy elems))
@@ -199,12 +199,12 @@ peepholeExpr expr = case expr of
199199
-- For example, it might be coming from an unsafe_from_ordinal that is
200200
-- under a case branch that would be dead for all invalid indices.
201201
if 0 <= ord && fromIntegral ord < length elems
202-
then Left $ elems !! fromIntegral ord
203-
else Right expr
204-
_ -> Right expr
202+
then Atom $ elems !! fromIntegral ord
203+
else expr
204+
_ -> expr
205205
-- TODO: Apply a function to literals when it has a cheap body?
206206
-- Think, partial evaluation of threefry.
207-
_ -> return $ Right expr
207+
_ -> return expr
208208
where isFinTabTy = \case
209209
TabPi (TabPiType (_:>(IxType _ (IxDictRawFin _))) _) -> True
210210
_ -> False
@@ -277,9 +277,8 @@ ulExpr expr = case expr of
277277
_ -> nothingSpecial
278278
where
279279
inc i = modify \(ULS n) -> ULS (n + i)
280-
nothingSpecial = inc 1 >> (visitGeneric expr >>= liftEnvReaderM . peepholeExpr) >>= \case
281-
Left x -> return x
282-
Right e -> emitExpr e
280+
nothingSpecial = inc 1 >> (visitGeneric expr >>= liftEnvReaderM . peepholeExpr)
281+
>>= emitExprToAtom
283282
unrollBlowupThreshold = 12
284283
withLocalAccounting m = do
285284
oldCost <- get

src/lib/Simplify.hs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -758,9 +758,7 @@ simplifyGenericOp op = do
758758
(substM >=> getRepType)
759759
(simplifyAtom >=> toDataAtomIgnoreRecon)
760760
(error "shouldn't have lambda left")
761-
result <- liftEnvReaderM (peepholeOp $ toPrimOp op') >>= \case
762-
Left a -> return a
763-
Right op'' -> emitOp op''
761+
result <- liftEnvReaderM (peepholeOp $ toPrimOp op') >>= emitExprToAtom
764762
liftSimpAtom ty result
765763
{-# INLINE simplifyGenericOp #-}
766764

src/lib/Types/Core.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ type SAtom = Atom SimpIR
425425
type SType = Type SimpIR
426426
type SExpr = Expr SimpIR
427427
type SBlock = Block SimpIR
428+
type SAlt = Alt SimpIR
428429
type SDecl = Decl SimpIR
429430
type SDecls = Decls SimpIR
430431
type SAtomName = AtomName SimpIR

tests/inline-tests.dx

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,44 @@ def id'(x:Nat) -> Nat = x
8787
sum (for i:(Fin 2) j:(..i). ordinal j)[ix]
8888
-- CHECK: 1
8989
-- CHECK-NOT: Compiler bug
90+
91+
-- CHECK-LABEL: Inlining simplifies case-of-known-constructor
92+
"Inlining simplifies case-of-known-constructor"
93+
94+
-- Inlining xs exposes a case-of-known-constructor opportunity here;
95+
-- the first inlining pass doesn't take it (yet) because it's
96+
-- conservative about inlining `i` into the body of `xs`, but the
97+
-- second pass does.
98+
%passes inline
99+
:pp
100+
xs = for i:(Either (Fin 3) (Fin 4)).
101+
case i of
102+
Left k -> 1
103+
Right k -> 2
104+
for j:(Fin 3). xs[Left j]
105+
-- CHECK: === inline ===
106+
-- CHECK: for
107+
-- CHECK: case
108+
-- CHECK: === inline ===
109+
-- CHECK: for
110+
-- CHECK-NOT: case
111+
112+
-- CHECK-LABEL: Inlining carries out the case-of-case optimization
113+
"Inlining carries out the case-of-case optimization"
114+
115+
-- Before inlining there are two cases, but attempting to inline `x`
116+
-- reveals a case-of-case opprtunity, which in turn exposes
117+
-- case-of-known-constructor in each branch, leading to just one case
118+
-- in the end.
119+
%passes inline
120+
:pp
121+
x = if id'(3) > 2
122+
then Just 4
123+
else Nothing
124+
case x of
125+
Just a -> a * a
126+
Nothing -> 0
127+
-- CHECK: === inline ===
128+
-- CHECK: case
129+
-- CHECK-NOT: case
130+
-- CHECK: === inline ===

0 commit comments

Comments
 (0)