1+ (* *
2+ Does escape analysis, and put heap allocations to stack allocation / registers
3+ based on the result.
4+ *)
5+ open Riscv_ssa
6+ open Riscv_opt
7+
8+ type escape_state =
9+ | NoEscape (* Does not escape the function *)
10+ | LocalEscape (* Escapes by getting captured by some closure *)
11+ | GlobalEscape (* Escapes by storing into some place *)
12+
13+ let join s1 s2 = match (s1, s2) with
14+ | GlobalEscape , _ | _ , GlobalEscape -> GlobalEscape
15+ | LocalEscape , _ | _ , LocalEscape -> LocalEscape
16+ | _ -> NoEscape
17+
18+ let print_escape =
19+ Hashtbl. iter (fun var state -> Printf. printf " %s: %s\n " var (match state with
20+ | NoEscape -> " no escape"
21+ | LocalEscape -> " local escape"
22+ | GlobalEscape -> " global escape" ))
23+
24+ let get_escape table (var : string ) =
25+ if not (Hashtbl. mem table var) then
26+ Hashtbl. add table var NoEscape ;
27+ Hashtbl. find table var
28+
29+
30+ (* *
31+ Does escape analysis.
32+ This does not yet support analysis of LocalEscape;
33+ every variable is categorized into either No- or GlobalEscape.
34+ *)
35+ let escape_analysis fn =
36+ (* Do escape analysis in the data-flow way. *)
37+ (* It's quite similar to liveness analysis in riscv_opt.ml. *)
38+ let escape_in = Hashtbl. create 1024 in
39+ let escape_out = Hashtbl. create 1024 in
40+
41+ let blocks = get_blocks fn in
42+ List. iter (fun name ->
43+ Hashtbl. add escape_in name (Hashtbl. create 64 );
44+ Hashtbl. add escape_out name (Hashtbl. create 64 );
45+ ) blocks;
46+
47+ let worklist = Basic_vec. of_list blocks in
48+ while Basic_vec. length worklist != 0 do
49+ let name = Basic_vec. pop worklist in
50+ let block = block_of name in
51+
52+ (* Escape_in should be the union of all escape_out *)
53+ Basic_vec. iter (fun pred ->
54+ let pred_out = Hashtbl. find escape_out pred in
55+ let block_in = Hashtbl. find escape_in name in
56+ Hashtbl. iter (fun var state ->
57+ let existing = get_escape block_in var in
58+ Hashtbl. replace block_in var (join existing state)
59+ ) pred_out
60+ ) block.pred;
61+
62+ (* Now calculate escape_out based on it *)
63+ let old_out = Hashtbl. find escape_out name in
64+ let last_out = ref old_out in
65+ let new_out = Hashtbl. copy old_out in
66+ let changed = ref true in
67+
68+ let replace var state =
69+ Hashtbl. replace new_out var.name state
70+ in
71+
72+ while ! changed do
73+ changed := false ;
74+ Basic_vec. iter (fun x -> match x with
75+ | Assign { rd; rs } ->
76+ replace rd (get_escape new_out rs.name)
77+
78+ | AssignLabel { rd; _ } -> replace rd GlobalEscape
79+ | Return x -> replace x GlobalEscape
80+
81+ | Call { rd; args }
82+ | CallExtern { rd; args } ->
83+ List. iter (fun arg ->
84+ replace arg GlobalEscape
85+ ) args;
86+ replace rd GlobalEscape
87+
88+ | Store { rd; rs }
89+ | Addi { rd; rs } ->
90+ let ed = get_escape new_out rd.name in
91+ let es = get_escape new_out rs.name in
92+ let state = join ed es in
93+
94+ replace rd state;
95+ replace rs state
96+
97+ | Add { rd; rs1; rs2 }
98+ | Sub { rd; rs1; rs2 } ->
99+ let ed = get_escape new_out rd.name in
100+ let es1 = get_escape new_out rs1.name in
101+ let es2 = get_escape new_out rs2.name in
102+ let state = (join ed (join es1 es2)) in
103+
104+ replace rd state;
105+ replace rs1 state;
106+ replace rs2 state
107+
108+ | Phi { rd; rs } ->
109+ let state =
110+ List. fold_left (fun acc (var , _ ) ->
111+ join acc (get_escape new_out var.name)
112+ ) NoEscape rs
113+ in
114+ replace rd state;
115+ List. iter (fun (var , _ ) -> replace var state) rs
116+
117+ | _ -> () ) block.body;
118+
119+ Hashtbl. iter (fun var state ->
120+ if state != get_escape ! last_out var then
121+ changed := true
122+ ) new_out;
123+ last_out := new_out;
124+ done ;
125+
126+ (* If anything changes, put it back to queue *)
127+ let changed = ref false in
128+ Hashtbl. iter (fun var state ->
129+ if state != get_escape old_out var then
130+ changed := true
131+ ) new_out;
132+
133+ (* Note this `!` does not mean not *)
134+ if ! changed then (
135+ Hashtbl. replace escape_out name new_out;
136+ Basic_vec. iter (fun x -> Basic_vec. push worklist x) block.succ
137+ )
138+ done ;
139+
140+ escape_out
141+
142+ (* * Reforms `malloc` on heap to `alloca` on stack when possible. *)
143+ let malloc_to_alloca fn =
144+ let blocks = get_blocks fn in
145+ let escape_data = escape_analysis fn in
146+ List. iter (fun name ->
147+ let block = block_of name in
148+ let body = block.body |> Basic_vec. to_list in
149+ let escaped = Hashtbl. find escape_data name in
150+ let changed = List. map (fun x -> match x with
151+ | Malloc { rd; size } ->
152+ if get_escape escaped rd.name = NoEscape then
153+ Alloca { rd; size }
154+ else
155+ Malloc { rd; size }
156+ | w -> w) body in
157+ block.body < - changed |> Basic_vec. of_list
158+ ) blocks
159+
160+ let lower_malloc ssa =
161+ iter_fn malloc_to_alloca ssa
0 commit comments