Skip to content

Commit 452c5b5

Browse files
committed
work queue
don't normalize twice
1 parent 1d5915c commit 452c5b5

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

clash-lib/clash-lib.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ Library
157157
-- TODO bounds
158158
lifted-async,
159159
lifted-base,
160+
lockfree-queue,
160161
monad-control,
161162
mtl >= 2.1.2 && < 2.3,
162163
ordered-containers >= 0.2 && < 0.3,

clash-lib/src/Clash/Core/VarEnv.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ module Clash.Core.VarEnv
4444
, emptyVarSet
4545
, unitVarSet
4646
-- ** Modification
47+
, extendVarSet
4748
, delVarSetByKey
4849
, unionVarSet
4950
, differenceVarSet

clash-lib/src/Clash/Normalize.hs

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,21 @@ import qualified Control.Concurrent.MVar.Lifted as MVar
2323
import Control.Concurrent.Supply (Supply)
2424
import Control.Exception (throw)
2525
import qualified Control.Lens as Lens
26-
import Control.Monad (when)
26+
import Control.Monad (when, unless)
2727
import qualified Control.Monad.IO.Class as Monad (liftIO)
2828
import Control.Monad.State.Strict (State)
29+
import Data.Bifunctor (bimap)
2930
import Data.Default (def)
3031
import Data.Either (lefts,partitionEithers)
32+
import Data.Foldable (traverse_)
3133
import qualified Data.HashMap.Strict as HashMap
3234
import Data.List
3335
(intersect, mapAccumL)
3436
import qualified Data.Map as Map
3537
import qualified Data.Maybe as Maybe
3638
import qualified Data.Set as Set
3739
import qualified Data.Set.Lens as Lens
40+
import qualified Data.Concurrent.Queue.MichaelScott as MS
3841

3942
#if MIN_VERSION_prettyprinter(1,7,0)
4043
import Prettyprinter (vcat)
@@ -66,8 +69,8 @@ import Clash.Core.TyCon (TyConMap)
6669
import Clash.Core.Type (isPolyTy)
6770
import Clash.Core.Var (Id, varName, varType)
6871
import Clash.Core.VarEnv
69-
(VarEnv, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv,
70-
extendVarEnv, lookupVarEnv, mapVarEnv, mapMaybeVarEnv,
72+
(VarEnv, VarSet, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv, emptyVarSet,
73+
extendVarEnv, extendVarSet, lookupVarEnv, mapVarEnv, mapMaybeVarEnv,
7174
mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv)
7275
import Clash.Debug (traceIf)
7376
import Clash.Driver.Types
@@ -150,11 +153,30 @@ runNormalization env supply globals typeTrans peEval eval rcsMap lock entities s
150153

151154
normalize :: [Id] -> NormalizeSession BindingMap
152155
normalize tops = do
153-
normBinds <- Async.mapConcurrently normalize' tops
154-
pure (mkVarEnv (concat normBinds))
155-
156-
normalize' :: Id -> NormalizeSession [(Id, Binding Term)]
157-
normalize' nm = do
156+
q <- Monad.liftIO MS.newQ
157+
traverse_ (Monad.liftIO . MS.pushL q) tops
158+
binds <- MVar.newMVar (emptyVarSet, [])
159+
-- one thread per top-level binding
160+
Async.mapConcurrently_ (\_ -> normalizeStep q binds) tops
161+
mkVarEnv . snd <$> MVar.readMVar binds
162+
163+
normalizeStep
164+
:: MS.LinkedQueue Id
165+
-> MVar (VarSet, [(Id, Binding Term)])
166+
-> NormalizeSession ()
167+
normalizeStep q binds = do
168+
res <- Monad.liftIO $ MS.tryPopR q
169+
case res of
170+
Just id' -> do
171+
(bound, _) <- MVar.readMVar binds
172+
unless (id' `elemVarSet` bound) $ do
173+
pair <- normalize' id' q
174+
MVar.modifyMVar_ binds (pure . bimap (`extendVarSet` id') (pair:))
175+
normalizeStep q binds
176+
Nothing -> pure ()
177+
178+
normalize' :: Id -> MS.LinkedQueue Id -> NormalizeSession (Id, Binding Term)
179+
normalize' nm q = do
158180
bndrsV <- Lens.use bindings
159181
exprM <- MVar.withMVar bndrsV (pure . lookupVarEnv nm)
160182
let nmS = showPpr (varName nm)
@@ -207,8 +229,8 @@ normalize' nm = do
207229

208230
-- traceM ("normalize: end: " <> nmS)
209231

210-
normChildren <- Async.mapConcurrently normalize' toNormalize
211-
return ((nm, tmNorm) : concat normChildren)
232+
traverse_ (Monad.liftIO . MS.pushL q) toNormalize
233+
pure (nm, tmNorm)
212234
else
213235
do
214236
-- Throw an error for unrepresentable topEntities and functions
@@ -230,7 +252,7 @@ normalize' nm = do
230252
, showPpr (coreTypeOf nm')
231253
, ") has a non-representable return type."
232254
, " Not normalising:\n", showPpr tm] )
233-
(return [(nm,(Binding nm' sp inl pr tm r))])
255+
(return (nm,(Binding nm' sp inl pr tm r)))
234256

235257

236258
Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"

0 commit comments

Comments
 (0)