Skip to content

Commit bac03f8

Browse files
committed
Teach occurrence analysis to hoist using the free variables it already computed, instead of re-traversing the term.
Pull out a little common machinery between this and DCE, where we also do this optimization. On the BFGS example, this speeds up occurrence analysis by 60-70%, and (according to the profiler) is worth something like a 7% improvement to compilation time end-to-end. Also make two minor improvements to classes relating to WriterT1 that came up while experimenting with this.
1 parent 0873300 commit bac03f8

File tree

5 files changed

+66
-41
lines changed

5 files changed

+66
-41
lines changed

src/lib/MTL1.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ instance (SinkableE w, Monoid1 w, ScopeReader m) => ScopeReader (WriterT1 w m) w
8888
getDistinct = lift11 getDistinct
8989
{-# INLINE getDistinct #-}
9090

91-
instance ( SinkableE w, HoistableE w, Monoid1 w
91+
instance ( SinkableE w, Monoid1 w
9292
, HoistableState w, EnvExtender m)
9393
=> EnvExtender (WriterT1 w m) where
9494
refreshAbs ab cont = WriterT1 \s -> do
@@ -113,6 +113,8 @@ instance MonadTrans11 (ReaderT1 r) where
113113
lift11 = ReaderT1 . lift
114114
{-# INLINE lift11 #-}
115115

116+
deriving instance MonadWriter s (m n) => MonadWriter s (ReaderT1 r m n)
117+
116118
deriving instance MonadState s (m n) => MonadState s (ReaderT1 r m n)
117119

118120
instance (SinkableE r, EnvReader m) => EnvReader (ReaderT1 r m) where

src/lib/Name.hs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2700,7 +2700,7 @@ unsafeCoerceNameSet = TrulyUnsafe.unsafeCoerce
27002700
-- XXX: there are privileged functions that depend on `HoistableE` instances
27012701
-- being correct.
27022702
class HoistableE (e::E) where
2703-
freeVarsE :: e n-> NameSet n
2703+
freeVarsE :: e n -> NameSet n
27042704
default freeVarsE :: (GenericE e, HoistableE (RepE e)) => e n -> NameSet n
27052705
freeVarsE e = freeVarsE $ fromE e
27062706

@@ -2759,6 +2759,19 @@ hoistToTop e =
27592759
[] -> HoistSuccess $ unsafeCoerceE e
27602760
leakedNames -> HoistFailure leakedNames
27612761

2762+
-- User is responsible for making sure that the `NameSet n` really is the set of
2763+
-- free variables of the `e n`.
2764+
data CachedFVs e n = UnsafeCachedFVs {
2765+
_cachedFVs :: (NameSet n), fromCachedFVs :: (e n) }
2766+
instance HoistableE e => HoistableE (CachedFVs e) where
2767+
freeVarsE (UnsafeCachedFVs fvs _) = fvs
2768+
2769+
hoistViaCachedFVs :: (BindsNames b, HoistableE e) =>
2770+
b n l -> CachedFVs e l -> HoistExcept (e n)
2771+
hoistViaCachedFVs b withFvs = case hoist b withFvs of
2772+
HoistSuccess withFvs' -> HoistSuccess $ fromCachedFVs withFvs'
2773+
HoistFailure err -> HoistFailure err
2774+
27622775
sinkFromTop :: SinkableE e => e VoidS -> e n
27632776
sinkFromTop = unsafeCoerceE
27642777
{-# INLINE sinkFromTop #-}
@@ -3293,6 +3306,13 @@ mapNameMap :: (a -> b) -> NameMap c a n -> (NameMap c b n)
32933306
mapNameMap f (UnsafeNameMap raw) = UnsafeNameMap $ fmap f raw
32943307
{-# INLINE mapNameMap #-}
32953308

3309+
keysNameMap :: NameMap c a n -> [Name c n]
3310+
keysNameMap = map fst . toListNameMap
3311+
{-# INLINE keysNameMap #-}
3312+
3313+
keySetNameMap :: (Color c) => NameMap c a n -> NameSet n
3314+
keySetNameMap nmap = freeVarsE $ ListE $ keysNameMap nmap
3315+
32963316
instance SinkableE (NameMap c a) where
32973317
sinkingProofE = undefined
32983318

@@ -3338,6 +3358,12 @@ mapNameMapE :: (e1 n -> e2 n)
33383358
mapNameMapE f (NameMapE nmap) = NameMapE $ mapNameMap f nmap
33393359
{-# INLINE mapNameMapE #-}
33403360

3361+
keysNameMapE :: NameMapE c e n -> [Name c n]
3362+
keysNameMapE (NameMapE nmap) = keysNameMap nmap
3363+
3364+
keySetNameMapE :: (Color c) => NameMapE c e n -> NameSet n
3365+
keySetNameMapE (NameMapE nmap) = keySetNameMap nmap
3366+
33413367
instance SinkableE e => SinkableE (NameMapE c e) where
33423368
sinkingProofE = undefined
33433369

src/lib/OccAnalysis.hs

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -289,18 +289,17 @@ occNest a (Abs decls ans) = case decls of
289289
\d'@(Let b' (DeclBinding _ expr')) rest -> do
290290
exprIx <- summaryExpr $ sink expr'
291291
extend b' exprIx do
292-
below <- occNest (sink a) rest
293-
checkAllFreeVariablesMentioned below
292+
(below, belowfvs) <- isolated do
293+
occNest (sink a) rest >>= wrapWithCachedFVs
294+
modify (<> belowfvs)
294295
accessInfo <- getAccessInfo $ binderName d'
295296
let usage = usageInfo accessInfo
296297
let dceAttempt = case isPureDecl of
297-
False -> ElimFailure d' usage below
298+
False -> ElimFailure d' usage $ fromCachedFVs below
298299
True ->
299-
-- Or hoistUsingCachedFVs in the monad, if we decide to do
300-
-- that optimization
301-
case hoist d' below of
300+
case hoistViaCachedFVs d' below of
302301
HoistSuccess below' -> ElimSuccess below'
303-
HoistFailure _ -> ElimFailure d' usage below
302+
HoistFailure _ -> ElimFailure d' usage $ fromCachedFVs below
304303
return dceAttempt
305304
case dceAttempt of
306305
ElimSuccess below' -> return below'
@@ -320,19 +319,21 @@ occNest a (Abs decls ans) = case decls of
320319
let binding'' = DeclBinding ann expr
321320
return $ Abs (Nest (Let b' binding'') ds'') ans''
322321

323-
checkAllFreeVariablesMentioned :: HoistableE e => e n -> OCCM n ()
324-
checkAllFreeVariablesMentioned e = do
322+
wrapWithCachedFVs :: forall e n. HoistableE e => e n -> OCCM n (CachedFVs e n)
323+
wrapWithCachedFVs e = do
324+
FV fvMap <- get
325+
let fvs = keySetNameMapE fvMap
325326
#ifdef DEX_DEBUG
326-
FV fvs <- get
327-
forM_ (nameSetToList (freeVarsE e)) \name ->
328-
case lookupNameMapE name fvs of
329-
Just _ -> return ()
330-
Nothing -> error $ "Free variable map missing free variable " ++ show name
327+
let fvsOk = map getRawName (freeVarsList e :: [SAtomName n]) == nameSetRawNames fvs
331328
#else
332-
void $ return e -- Refer to `e` in this branch to avoid a GHC warning
333-
return ()
334-
{-# INLINE checkAllFreeVariablesMentioned #-}
329+
-- Verification of this invariant defeats the performance benefits of
330+
-- avoiding the extra traversal (e.g. actually having linear complexity),
331+
-- so we only do that in debug builds.
332+
let fvsOk = True
335333
#endif
334+
case fvsOk of
335+
True -> return $ UnsafeCachedFVs fvs e
336+
False -> error $ "Free variables were computed incorrectly."
336337

337338
instance HasOCC (DeclBinding SimpIR) where
338339
occ a (DeclBinding ann expr) = do
@@ -407,14 +408,11 @@ instance HasOCC (Hof SimpIR) where
407408
modify (<> useManyTimes bodyFV)
408409
return body'
409410
RunReader ini bd -> do
410-
ini' <- occ accessOnce ini
411411
iniIx <- summary ini
412-
bd' <- oneShot a [Deterministic [], iniIx]bd
412+
bd' <- oneShot a [Deterministic [], iniIx] bd
413+
ini' <- occ accessOnce ini
413414
return $ RunReader ini' bd'
414415
RunWriter Nothing (BaseMonoid empty combine) bd -> do
415-
-- We will process the combining function when we meet it in MExtend ops
416-
-- (but we won't attempt to eliminate dead code in it).
417-
empty' <- occ accessOnce empty
418416
-- There is no way to read from the reference in a Writer, so the only way
419417
-- an indexing expression can depend on it is by referring to the
420418
-- reference itself. One way to so refer that is opaque to occurrence
@@ -428,17 +426,20 @@ instance HasOCC (Hof SimpIR) where
428426
-- different references across loop iterations are not distinguishable.
429427
-- The same argument holds for the heap parameter.
430428
bd' <- oneShot a [Deterministic [], Deterministic []] bd
429+
-- We will process the combining function when we meet it in MExtend ops
430+
-- (but we won't attempt to eliminate dead code in it).
431+
empty' <- occ accessOnce empty
431432
return $ RunWriter Nothing (BaseMonoid empty' combine) bd'
432433
RunWriter (Just _) _ _ ->
433434
error "Expecting to do occurrence analysis before destination passing."
434435
RunState Nothing ini bd -> do
435-
ini' <- occ accessOnce ini
436436
-- If we wanted to be more precise, the summary for the reference should
437437
-- be something about the stuff that might flow into the `put` operations
438438
-- affecting that reference. Using `IxAll` is a conservative
439439
-- approximation (in downstream analysis it means "assume I touch every
440440
-- value").
441-
bd' <- oneShot a [Deterministic [], IxAll]bd
441+
bd' <- oneShot a [Deterministic [], IxAll] bd
442+
ini' <- occ accessOnce ini
442443
return $ RunState Nothing ini' bd'
443444
RunState (Just _) _ _ ->
444445
error "Expecting to do occurrence analysis before destination passing."
@@ -465,23 +466,25 @@ occWithBinder
465466
-> (forall l. DExt n l => Binder SimpIR n l -> e l -> OCCM l a)
466467
-> OCCM n a
467468
occWithBinder (Abs (b:>ty) body) cont = do
468-
ty' <- occTy ty
469-
refreshAbs (Abs (b:>ty') body) cont
469+
(ty', fvs) <- isolated $ occTy ty
470+
ans <- refreshAbs (Abs (b:>ty') body) cont
471+
modify (<> fvs)
472+
return ans
470473
{-# INLINE occWithBinder #-}
471474

472475
instance HasOCC (RefOp SimpIR) where
473476
occ _ = \case
474477
MExtend (BaseMonoid empty combine) val -> do
478+
valIx <- summary val
479+
-- Treat the combining function as inlined here and called once
480+
combine' <- oneShot accessOnce [Deterministic [], valIx] combine
475481
val' <- occ accessOnce val
476-
valIx <- summary val'
477482
-- TODO(precision) The empty value of the monoid is presumably dead here,
478483
-- but we pretend like it's not to make sure that occurrence analysis
479484
-- results mention every free variable in the traversed expression. This
480485
-- may lead to missing an opportunity to inline something into the empty
481486
-- value of the given monoid, since references thereto will be overcounted.
482487
empty' <- occ accessOnce empty
483-
-- Treat the combining function as inlined here and called once
484-
combine' <- oneShot accessOnce [Deterministic [], valIx] combine
485488
return $ MExtend (BaseMonoid empty' combine') val'
486489
-- I'm pretty sure the others are all strict, and not usefully analyzable
487490
-- for what they do to the incoming access pattern.

src/lib/Occurrence.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,7 @@ instance MaxPlus UsageInfo where
727727
usageInfo :: AccessInfo n -> UsageInfo
728728
usageInfo (AccessInfo s dyn) =
729729
UsageInfo s $ approxConst $ collapse $ interp dyn
730+
{-# SCC usageInfo #-}
730731

731732
-- === Notes ===
732733

src/lib/Optimize.hs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,6 @@ dceBlock' (Abs decls ans) = do
463463
modify (<> old)
464464
return block
465465

466-
data CachedFVs e n = UnsafeCachedFVs { _cachedFVs :: (NameSet n), fromCachedFVs :: (e n) }
467-
instance HoistableE (CachedFVs e) where
468-
freeVarsE (UnsafeCachedFVs fvs _) = fvs
469-
470466
wrapWithCachedFVs :: HoistableE e => e n -> DCEM n (CachedFVs e n)
471467
wrapWithCachedFVs e = do
472468
FV fvs <- get
@@ -482,12 +478,9 @@ wrapWithCachedFVs e = do
482478
True -> return $ UnsafeCachedFVs fvs e
483479
False -> error $ "Free variables were computed incorrectly."
484480

485-
hoistUsingCachedFVs :: (BindsNames b, HoistableE e) => b n l -> e l -> DCEM l (HoistExcept (e n))
486-
hoistUsingCachedFVs b e = do
487-
ec <- wrapWithCachedFVs e
488-
return $ case hoist b ec of
489-
HoistSuccess e' -> HoistSuccess $ fromCachedFVs e'
490-
HoistFailure err -> HoistFailure err
481+
hoistUsingCachedFVs :: (BindsNames b, HoistableE e) =>
482+
b n l -> e l -> DCEM l (HoistExcept (e n))
483+
hoistUsingCachedFVs b e = hoistViaCachedFVs b <$> wrapWithCachedFVs e
491484

492485
data ElimResult n where
493486
ElimSuccess :: Abs (Nest SDecl) SAtom n -> ElimResult n

0 commit comments

Comments
 (0)