Skip to content

Commit 209a90d

Browse files
committed
[LayoutOpt]: Add a greedy heuristic to ameliorate solver time in lieu of runtime performance.
1 parent 540b2b7 commit 209a90d

File tree

5 files changed

+125
-35
lines changed

5 files changed

+125
-35
lines changed

gibbon-compiler/src/Gibbon/Compiler.hs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ passes config@Config{dynflags} l0 = do
601601
should_fuse = gopt Opt_Fusion dynflags
602602
opt_layout_local = gopt Opt_Layout_Local dynflags
603603
opt_layout_global = gopt Opt_Layout_Global dynflags
604+
use_solver = gopt Opt_Layout_Use_Solver dynflags
604605
tcProg3 = L3.tcProg isPacked
605606
l0 <- go "freshen" freshNames l0
606607
l0 <- goE0 "typecheck" L0.tcProg l0
@@ -644,12 +645,12 @@ passes config@Config{dynflags} l0 = do
644645
-- Note: L1 -> L2
645646
l1 <- if opt_layout_local
646647
then do
647-
after_layout_out <- goE1 "optimizeADTLayoutLocal" locallyOptimizeDataConLayout l1
648+
after_layout_out <- goE1 "optimizeADTLayoutLocal" (locallyOptimizeDataConLayout (not use_solver)) l1
648649
flatten_after_opt <- goE1 "L1.flatten2" flattenL1 after_layout_out
649650
pure flatten_after_opt
650651
else if opt_layout_global
651652
then do
652-
after_layout_out <- goE1 "optimizeADTLayoutGlobal" globallyOptimizeDataConLayout l1
653+
after_layout_out <- goE1 "optimizeADTLayoutGlobal" (globallyOptimizeDataConLayout (not use_solver)) l1
653654
flatten_after_opt <- goE1 "L1.flatten2" flattenL1 after_layout_out
654655
pure flatten_after_opt
655656
else return l1

gibbon-compiler/src/Gibbon/DynFlags.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ data GeneralFlag
4141
| Opt_SimpleWriteBarrier -- ^ Disables eliminate-indirection-chains optimization.
4242
| Opt_Layout_Local -- ^ Optimize the layout of Algebraic data types locally
4343
| Opt_Layout_Global -- ^ Optimize the layout of Algebraic data types globally
44+
| Opt_Layout_Use_Solver -- ^ Use the Solver to optimize the layout of the data types.
4445
deriving (Show,Read,Eq,Ord)
4546

4647
-- | Exactly like GHC's ddump flags.
@@ -118,7 +119,8 @@ dynflagsParser = DynFlags <$> (S.fromList <$> many gflagsParser) <*> (S.fromList
118119
flag' Opt_NoEagerPromote (long "no-eager-promote" <> help "Disable eager promotion.") <|>
119120
flag' Opt_SimpleWriteBarrier (long "simple-write-barrier" <> help "Disables eliminate-indirection-chains optimization.") <|>
120121
flag' Opt_Layout_Local (long "opt-layout-local" <> help "Optimizes the Layout of Algebraic data types locally") <|>
121-
flag' Opt_Layout_Global (long "opt-layout-global" <> help "Optimizes the Layout of Algebraic data types globally")
122+
flag' Opt_Layout_Global (long "opt-layout-global" <> help "Optimizes the Layout of Algebraic data types globally") <|>
123+
flag' Opt_Layout_Use_Solver (long "opt-layout-use-solver" <> help "Use the solver instead of a Greedy Heuristic")
122124

123125

124126
dflagsParser :: Parser DebugFlag

gibbon-compiler/src/Gibbon/Passes/AccessPatternsAnalysis.hs

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Gibbon.Passes.AccessPatternsAnalysis
22
( generateAccessGraphs,
3+
getGreedyOrder,
34
FieldMap,
45
DataConAccessMap,
56
)
@@ -57,7 +58,7 @@ generateAccessGraphs
5758
topologicallySortedNodes =
5859
P.map nodeFromVertex topologicallySortedVertices
5960
map = backtrackVariablesToDataConFields topologicallySortedNodes dcons
60-
edges =
61+
edges = S.toList $ S.fromList $
6162
( constructFieldGraph
6263
Nothing
6364
nodeFromVertex
@@ -69,9 +70,71 @@ generateAccessGraphs
6970
dcons
7071
accessMapsList = zipWith (\x y -> (x, y)) [dcons] [edges]
7172
accessMaps = M.fromList accessMapsList
72-
in M.insert funName accessMaps fieldMap --dbgTraceIt (sdoc (edges, map))
73+
in M.insert funName accessMaps fieldMap --dbgTraceIt (sdoc topologicallySortedVertices) dbgTraceIt ("\n") dbgTraceIt (sdoc (topologicallySortedVertices, edges)) dbgTraceIt ("\n")
7374
Nothing -> error "generateAccessGraphs: no CFG for function found!"
7475

76+
77+
78+
getGreedyOrder :: [((Integer, Integer), Integer)] -> Int -> [Integer]
79+
getGreedyOrder edges fieldLength =
80+
if edges == []
81+
then P.map P.toInteger [0 .. (fieldLength - 1)]
82+
else
83+
let partial_order = greedyOrderOfVertices edges
84+
completeOrder = P.foldl (\lst val -> if S.member val (S.fromList lst) then lst
85+
else lst ++ [val]
86+
) partial_order [0 .. (fieldLength - 1)]
87+
in dbgTraceIt (sdoc (edges, completeOrder)) P.map P.toInteger completeOrder
88+
89+
greedyOrderOfVertices :: [((Integer, Integer), Integer)] -> [Int]
90+
greedyOrderOfVertices ee = let edges' = P.map (\((a, b), c) -> ((P.fromIntegral a, P.fromIntegral b), P.fromIntegral c)) ee
91+
bounds = (\e -> let v = P.foldr (\((i, j), _) s -> S.insert j (S.insert i s)) S.empty e
92+
mini = minimum v
93+
maxi = maximum v
94+
in (mini, maxi)
95+
) edges'
96+
edgesWithoutWeight = P.map fst edges'
97+
graph = buildG bounds edgesWithoutWeight
98+
weightMap = P.foldr (\(e, w) mm -> M.insert e w mm) M.empty edges'
99+
v'' = greedyOrderOfVerticesHelper graph (topSort graph) weightMap S.empty
100+
in v'' -- dbgTraceIt (sdoc (v'', (M.elems weightMap)))
101+
102+
103+
greedyOrderOfVerticesHelper :: Graph -> [Int] -> M.Map (Int, Int) Int -> S.Set Int -> [Int]
104+
greedyOrderOfVerticesHelper graph vertices' weightMap visited = case vertices' of
105+
[] -> []
106+
x:xs -> if S.member x visited
107+
then greedyOrderOfVerticesHelper graph xs weightMap visited
108+
else let successors = reachable graph x
109+
removeCurr = S.toList $ S.delete x (S.fromList successors)
110+
orderedSucc = orderedSuccsByWeight removeCurr x weightMap visited
111+
visited' = P.foldr S.insert S.empty orderedSucc
112+
v'' = greedyOrderOfVerticesHelper graph xs weightMap visited'
113+
in if successors == [x]
114+
then orderedSucc ++ v'' --dbgTraceIt (sdoc (v'', orderedSucc))
115+
else [x] ++ orderedSucc ++ v''
116+
117+
orderedSuccsByWeight :: [Int] -> Int -> M.Map (Int, Int) Int -> S.Set Int -> [Int]
118+
orderedSuccsByWeight s i weightMap visited = case s of
119+
[] -> []
120+
_ -> let vertexWithMaxWeight = P.foldr (\v (v', maxx) -> let w = M.lookup (i, v) weightMap
121+
in case w of
122+
Nothing -> (-1, -1)
123+
Just w' -> if w' > maxx
124+
then (v, w')
125+
else (v', maxx)
126+
) (-1, -1) s
127+
in if fst vertexWithMaxWeight == -1
128+
then []
129+
else
130+
let removeVertexWithMaxWeight = S.toList $ S.delete (fst vertexWithMaxWeight) (S.fromList s)
131+
in if S.member (fst vertexWithMaxWeight) visited
132+
then orderedSuccsByWeight removeVertexWithMaxWeight i weightMap visited
133+
else fst vertexWithMaxWeight : orderedSuccsByWeight removeVertexWithMaxWeight i weightMap visited --dbgTraceIt (sdoc (s, removeVertexWithMaxWeight, vertexWithMaxWeight))
134+
135+
136+
137+
75138
backtrackVariablesToDataConFields ::
76139
(FreeVars (e l d), Ord l, Ord d, Ord (e l d), Out d, Out l) =>
77140
[(((PreExp e l d), Integer), Integer, [Integer])] ->
@@ -81,9 +144,9 @@ backtrackVariablesToDataConFields graph dcon =
81144
case graph of
82145
[] -> M.empty
83146
x : xs ->
84-
let newMap = processVertex graph x M.empty dcon
147+
let newMap = processVertex graph x M.empty dcon
85148
mlist = M.toList (newMap)
86-
m = backtrackVariablesToDataConFields xs dcon
149+
m = backtrackVariablesToDataConFields xs dcon
87150
mlist' = M.toList m
88151
newMap' = M.fromList (mlist ++ mlist')
89152
in newMap'
@@ -93,21 +156,21 @@ processVertex ::
93156
[(((PreExp e l d), Integer), Integer, [Integer])] ->
94157
(((PreExp e l d), Integer), Integer, [Integer]) ->
95158
VariableMap ->
96-
DataCon ->
159+
DataCon ->
97160
VariableMap
98161
processVertex graph node map dataCon =
99162
case node of
100163
((expression, likelihood), id, succ) ->
101164
case expression of
102165
DataConE loc dcon args ->
103166
if dcon == dataCon
104-
then
167+
then
105168
let freeVariables = L.concat (P.map (\x -> S.toList (gFreeVars x)) args)
106169
maybeIndexes = P.map (getDataConIndexFromVariable graph) freeVariables
107170
mapList = M.toList map
108171
newMapList = P.zipWith (\x y -> (x, y)) freeVariables maybeIndexes
109172
in M.fromList (mapList ++ newMapList)
110-
else map
173+
else map
111174
_ -> map
112175

113176
getDataConIndexFromVariable ::

gibbon-compiler/src/Gibbon/Passes/OptimizeADTLayout.hs

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
{-# HLINT ignore "Redundant lambda" #-}
55
{-# HLINT ignore "Use tuple-section" #-}
66
module Gibbon.Passes.OptimizeADTLayout
7-
( optimizeADTLayout,
7+
(
88
globallyOptimizeDataConLayout,
99
locallyOptimizeDataConLayout
1010
)
@@ -34,6 +34,7 @@ import Gibbon.Passes.AccessPatternsAnalysis
3434
( DataConAccessMap,
3535
FieldMap,
3636
generateAccessGraphs,
37+
getGreedyOrder
3738
)
3839
import Gibbon.Passes.CallGraph
3940
( ProducersMap (..),
@@ -66,11 +67,11 @@ import Gibbon.Passes.Flatten (flattenL1)
6667
type FieldOrder = M.Map DataCon [Integer]
6768

6869
-- TODO: Make FieldOrder an argument passed to shuffleDataCon function.
69-
optimizeADTLayout ::
70-
Prog1 ->
71-
PassM Prog1
72-
optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} =
73-
do
70+
--optimizeADTLayout ::
71+
-- Prog1 ->
72+
-- PassM Prog1
73+
--optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} =
74+
--do
7475
-- let list_pair_func_dcon =
7576
-- concatMap ( \fn@(FunDef {funName, funMeta = FunMeta {funOptLayout = layout}}) ->
7677
-- case layout of
@@ -124,25 +125,25 @@ optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} =
124125
-- p
125126
-- pure prg'
126127
--prg' <- runUntilFixPoint prg
127-
globallyOptimizeDataConLayout prg
128+
--globallyOptimizeDataConLayout prg
128129
--pure prg'
129130
--generateCopyFunctionsForFunctionsThatUseOptimizedVariable (toVar funcName) (dcon ++ "Optimized") fieldorder prg'
130131
--_ -> error "OptimizeFieldOrder: handle user constraints"
131132

132133

133-
locallyOptimizeDataConLayout :: Prog1 -> PassM Prog1
134-
locallyOptimizeDataConLayout prg1 = do
135-
runUntilFixPoint prg1
134+
locallyOptimizeDataConLayout :: Bool -> Prog1 -> PassM Prog1
135+
locallyOptimizeDataConLayout useGreedy prg1 = do
136+
runUntilFixPoint useGreedy prg1
136137

137138

138139

139-
runUntilFixPoint :: Prog1 -> PassM Prog1
140-
runUntilFixPoint prog1 = do
141-
prog1' <- producerConsumerLayoutOptimization prog1
140+
runUntilFixPoint :: Bool -> Prog1 -> PassM Prog1
141+
runUntilFixPoint useGreedy prog1 = do
142+
prog1' <- producerConsumerLayoutOptimization prog1 useGreedy
142143
prog1'' <- flattenL1 prog1'
143144
if prog1 == prog1''
144145
then return prog1
145-
else runUntilFixPoint prog1''
146+
else runUntilFixPoint useGreedy prog1''
146147

147148

148149
dataConsInFunBody :: Exp1 -> S.Set DataCon
@@ -172,8 +173,8 @@ dataConsInFunBody funBody = case funBody of
172173
MapE {} -> error "getGeneratedVariable: TODO MapE"
173174
FoldE {} -> error "getGeneratedVariable: TODO FoldE"
174175

175-
producerConsumerLayoutOptimization :: Prog1 -> PassM Prog1
176-
producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do
176+
producerConsumerLayoutOptimization :: Prog1 -> Bool -> PassM Prog1
177+
producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} useGreedy = do
177178
-- TODO: make a custom function name printer that guarantees that functions starting with _ are auto-generated.
178179
let funsToOptimize = P.concatMap (\FunDef{funName} -> ([funName | not $ isInfixOf "_" (fromVar funName)])
179180
) $ M.elems fundefs
@@ -193,7 +194,7 @@ producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do
193194
Just x -> x
194195
Nothing -> error "producerConsumerLayoutOptimization: expected a function definition!!"
195196
let fieldOrder = getAccessGraph f dcon
196-
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder
197+
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder useGreedy
197198
case result of
198199
Nothing -> pure pr --dbgTraceIt (sdoc (result, fname, fieldOrder))
199200
Just (ddefs', fundef', fieldorder) -> let fundefs' = M.delete fname fds
@@ -207,8 +208,8 @@ producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do
207208
P.foldrM lambda prg linearizeDcons --dbgTraceIt (sdoc linearizeDcons)
208209

209210

210-
globallyOptimizeDataConLayout :: Prog1 -> PassM Prog1
211-
globallyOptimizeDataConLayout prg@Prog{ddefs, fundefs, mainExp} = do
211+
globallyOptimizeDataConLayout :: Bool -> Prog1 -> PassM Prog1
212+
globallyOptimizeDataConLayout useGreedy prg@Prog{ddefs, fundefs, mainExp} = do
212213
-- TODO: make a custom function name printer that guarantees that functions starting with _ are auto-generated.
213214
let funsToOptimize = P.concatMap (\FunDef{funName} -> ([funName | not $ isInfixOf "_" (fromVar funName)])
214215
) $ M.elems fundefs
@@ -261,7 +262,7 @@ globallyOptimizeDataConLayout prg@Prog{ddefs, fundefs, mainExp} = do
261262
let fd = case maybeFd of
262263
Just x -> x
263264
Nothing -> error "globallyOptimizeDataConLayout: expected a function definition!!"
264-
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder
265+
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder useGreedy
265266
case result of
266267
Nothing -> pure pr
267268
Just (ddefs', fundef', fieldorder) -> let fundefs' = M.delete fname fds
@@ -491,12 +492,16 @@ getAccessGraph
491492

492493

493494

495+
496+
-- getGreedyFieldOrder :: Int -> DataCon -> FieldMap
497+
494498
optimizeFunctionWRTDataCon ::
495499
DDefs1 ->
496500
FunDef1 ->
497501
DataCon ->
498502
DataCon ->
499503
FieldMap ->
504+
Bool ->
500505
Maybe (DDefs1, FunDef1, FieldOrder)
501506
optimizeFunctionWRTDataCon
502507
ddefs
@@ -508,7 +513,9 @@ optimizeFunctionWRTDataCon
508513
}
509514
datacon
510515
newDcon
511-
fieldMap =
516+
fieldMap
517+
useGreedy = case useGreedy of
518+
False ->
512519
let field_len = P.length $ snd . snd $ lkp' ddefs datacon
513520
fieldorder =
514521
optimizeDataConOrderFunc
@@ -531,7 +538,24 @@ optimizeFunctionWRTDataCon
531538
fundef' = shuffleDataConFunBody True fieldorder fundef newDcon
532539
in Just (newDDefs, fundef', fieldorder) --dbgTraceIt (sdoc order) -- dbgTraceIt (sdoc fieldorder)
533540
_ -> error "more than one"
534-
541+
True ->
542+
let field_len = P.length $ snd . snd $ lkp' ddefs datacon
543+
edges' = case (M.lookup funName fieldMap) of
544+
Just d -> case (M.lookup datacon d) of
545+
Nothing -> error ""
546+
Just e -> e
547+
Nothing -> error ""
548+
greedy_order = getGreedyOrder edges' field_len
549+
fieldorder = M.insert datacon greedy_order M.empty
550+
in case M.toList fieldorder of
551+
[] -> Nothing --dbgTraceIt (sdoc fieldorder) dbgTraceIt (sdoc greedy_order)
552+
[(dcon, order)] -> let orignal_order = [0..(P.length order - 1)]
553+
in if orignal_order == P.map P.fromInteger order
554+
then Nothing
555+
else let newDDefs = optimizeDataCon (dcon, order) ddefs newDcon
556+
fundef' = shuffleDataConFunBody True fieldorder fundef newDcon
557+
in Just (newDDefs, fundef', fieldorder) --dbgTraceIt (sdoc order) -- dbgTraceIt (sdoc fieldorder) dbgTraceIt (sdoc greedy_order)
558+
_ -> error "more than one"
535559

536560
changeCallNameInRecFunction ::
537561
Var -> FunDef1 -> FunDef1

gibbon-compiler/tests/test-gibbon-examples.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -896,21 +896,21 @@ tests:
896896
run-modes: ["gibbon2", "gibbon3", "pointer"]
897897

898898
- name: layout1ContentSearchRunPipeline.hs
899-
test-flags: ["--no-gc", "--opt-layout-local"]
899+
test-flags: ["--no-gc", "--opt-layout-local", "--opt-layout-use-solver"]
900900
dir: examples/layout_bench
901901
answer-file: examples/layout_bench/layout1ContentSearchRunPipeline.ans
902902
failing: [interp1,pointer,gibbon1, gibbon3]
903903
run-modes: ["gibbon2"]
904904

905905
- name: manyFuncs.hs
906-
test-flags: ["--no-gc", "--opt-layout-local"]
906+
test-flags: ["--no-gc", "--opt-layout-local", "--opt-layout-use-solver"]
907907
dir: examples/layout_bench
908908
answer-file: examples/layout_bench/manyFuncsLocal.ans
909909
failing: [interp1,pointer,gibbon1, gibbon3]
910910
run-modes: ["gibbon2"]
911911

912912
- name: manyFuncs.hs
913-
test-flags: ["--no-gc", "--opt-layout-global"]
913+
test-flags: ["--no-gc", "--opt-layout-global", "--opt-layout-use-solver"]
914914
dir: examples/layout_bench
915915
answer-file: examples/layout_bench/manyFuncsGlobal.ans
916916
failing: [interp1,pointer,gibbon1, gibbon3]

0 commit comments

Comments
 (0)