Skip to content

Commit 0df31e6

Browse files
AdUhTkJmmengzhuo
authored andcommitted
Escape analysis
1 parent 5a45282 commit 0df31e6

File tree

4 files changed

+176
-1
lines changed

4 files changed

+176
-1
lines changed

src/riscv_opt_escape.ml

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
(**
2+
Does escape analysis, and put heap allocations to stack allocation / registers
3+
based on the result.
4+
*)
5+
open Riscv_ssa
6+
open Riscv_opt
7+
8+
type escape_state =
9+
| NoEscape (* Does not escape the function *)
10+
| LocalEscape (* Escapes by getting captured by some closure *)
11+
| GlobalEscape (* Escapes by storing into some place *)
12+
13+
let join s1 s2 = match (s1, s2) with
14+
| GlobalEscape, _ | _, GlobalEscape -> GlobalEscape
15+
| LocalEscape, _ | _, LocalEscape -> LocalEscape
16+
| _ -> NoEscape
17+
18+
let print_escape =
19+
Hashtbl.iter (fun var state -> Printf.printf "%s: %s\n" var (match state with
20+
| NoEscape -> "no escape"
21+
| LocalEscape -> "local escape"
22+
| GlobalEscape -> "global escape"))
23+
24+
let get_escape table (var: string) =
25+
if not (Hashtbl.mem table var) then
26+
Hashtbl.add table var NoEscape;
27+
Hashtbl.find table var
28+
29+
30+
(**
31+
Does escape analysis.
32+
This does not yet support analysis of LocalEscape;
33+
every variable is categorized into either No- or GlobalEscape.
34+
*)
35+
let escape_analysis fn =
36+
(* Do escape analysis in the data-flow way. *)
37+
(* It's quite similar to liveness analysis in riscv_opt.ml. *)
38+
let escape_in = Hashtbl.create 1024 in
39+
let escape_out = Hashtbl.create 1024 in
40+
41+
let blocks = get_blocks fn in
42+
List.iter (fun name ->
43+
Hashtbl.add escape_in name (Hashtbl.create 64);
44+
Hashtbl.add escape_out name (Hashtbl.create 64);
45+
) blocks;
46+
47+
let worklist = Basic_vec.of_list blocks in
48+
while Basic_vec.length worklist != 0 do
49+
let name = Basic_vec.pop worklist in
50+
let block = block_of name in
51+
52+
(* Escape_in should be the union of all escape_out *)
53+
Basic_vec.iter (fun pred ->
54+
let pred_out = Hashtbl.find escape_out pred in
55+
let block_in = Hashtbl.find escape_in name in
56+
Hashtbl.iter (fun var state ->
57+
let existing = get_escape block_in var in
58+
Hashtbl.replace block_in var (join existing state)
59+
) pred_out
60+
) block.pred;
61+
62+
(* Now calculate escape_out based on it *)
63+
let old_out = Hashtbl.find escape_out name in
64+
let last_out = ref old_out in
65+
let new_out = Hashtbl.copy old_out in
66+
let changed = ref true in
67+
68+
let replace var state =
69+
Hashtbl.replace new_out var.name state
70+
in
71+
72+
while !changed do
73+
changed := false;
74+
Basic_vec.iter (fun x -> match x with
75+
| Assign { rd; rs } ->
76+
replace rd (get_escape new_out rs.name)
77+
78+
| AssignLabel { rd; _ } -> replace rd GlobalEscape
79+
| Return x -> replace x GlobalEscape
80+
81+
| Call { rd; args }
82+
| CallExtern { rd; args } ->
83+
List.iter (fun arg ->
84+
replace arg GlobalEscape
85+
) args;
86+
replace rd GlobalEscape
87+
88+
| Store { rd; rs }
89+
| Addi { rd; rs } ->
90+
let ed = get_escape new_out rd.name in
91+
let es = get_escape new_out rs.name in
92+
let state = join ed es in
93+
94+
replace rd state;
95+
replace rs state
96+
97+
| Add { rd; rs1; rs2 }
98+
| Sub { rd; rs1; rs2 } ->
99+
let ed = get_escape new_out rd.name in
100+
let es1 = get_escape new_out rs1.name in
101+
let es2 = get_escape new_out rs2.name in
102+
let state = (join ed (join es1 es2)) in
103+
104+
replace rd state;
105+
replace rs1 state;
106+
replace rs2 state
107+
108+
| Phi { rd; rs } ->
109+
let state =
110+
List.fold_left (fun acc (var, _) ->
111+
join acc (get_escape new_out var.name)
112+
) NoEscape rs
113+
in
114+
replace rd state;
115+
List.iter (fun (var, _) -> replace var state) rs
116+
117+
| _ -> ()) block.body;
118+
119+
Hashtbl.iter (fun var state ->
120+
if state != get_escape !last_out var then
121+
changed := true
122+
) new_out;
123+
last_out := new_out;
124+
done;
125+
126+
(* If anything changes, put it back to queue *)
127+
let changed = ref false in
128+
Hashtbl.iter (fun var state ->
129+
if state != get_escape old_out var then
130+
changed := true
131+
) new_out;
132+
133+
(* Note this `!` does not mean not *)
134+
if !changed then (
135+
Hashtbl.replace escape_out name new_out;
136+
Basic_vec.iter (fun x -> Basic_vec.push worklist x) block.succ
137+
)
138+
done;
139+
140+
escape_out
141+
142+
(** Reforms `malloc` on heap to `alloca` on stack when possible. *)
143+
let malloc_to_alloca fn =
144+
let blocks = get_blocks fn in
145+
let escape_data = escape_analysis fn in
146+
List.iter (fun name ->
147+
let block = block_of name in
148+
let body = block.body |> Basic_vec.to_list in
149+
let escaped = Hashtbl.find escape_data name in
150+
let changed = List.map (fun x -> match x with
151+
| Malloc { rd; size } ->
152+
if get_escape escaped rd.name = NoEscape then
153+
Alloca { rd; size }
154+
else
155+
Malloc { rd; size }
156+
| w -> w) body in
157+
block.body <- changed |> Basic_vec.of_list
158+
) blocks
159+
160+
let lower_malloc ssa =
161+
iter_fn malloc_to_alloca ssa

src/riscv_opt_gather.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ let opt ssa =
1111
for i = 1 to 3 do
1212
Riscv_opt_inline.inline ssa;
1313
Riscv_opt_peephole.peephole ssa;
14+
Riscv_opt_escape.lower_malloc ssa;
1415
done;
1516

1617
let s = map_fn ssa_of_cfg ssa in

src/riscv_ssa.ml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ and t =
216216
| ExtArray of extern_array (* An array in `.data` section *)
217217
| CallExtern of call_data (* Call a C function *)
218218
| CallIndirect of call_indirect (* Call a function pointer *)
219-
| Malloc of malloc
219+
| Malloc of malloc (* Allocate on heap *)
220+
| Alloca of malloc (* Allocate on stack *)
220221
| Return of var
221222

222223
(* Note: *)
@@ -434,6 +435,9 @@ let to_string t =
434435

435436
| Malloc { rd; size } ->
436437
Printf.sprintf "malloc %s %d" rd.name size
438+
439+
| Alloca { rd; size } ->
440+
Printf.sprintf "alloca %s %d" rd.name size
437441

438442
| FnDecl { fn; args; body; } ->
439443
let args_str = String.concat ", " (List.map (fun x -> x.name) args) in
@@ -511,6 +515,7 @@ let rec reg_map fd fs t = match t with
511515
| GlobalVarDecl var -> GlobalVarDecl var
512516
| ExtArray arr -> ExtArray arr
513517
| Malloc { rd; size } -> Malloc { rd = fd rd; size }
518+
| Alloca { rd; size } -> Alloca { rd = fd rd; size }
514519
| Return var -> Return (fs var)
515520

516521
let reg_iter fd fs t =

test/interpreter.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,14 @@ int64_t interpret(std::string label) {
333333
continue;
334334
}
335335

336+
if (op == "alloca") {
337+
auto len = int_of(args[2]);
338+
339+
VAL(1) = (int64_t) alloca(len);
340+
OUTPUT(args[1], VAL(1));
341+
continue;
342+
}
343+
336344
if (op == "phi") {
337345
bool is_bad = true;
338346

0 commit comments

Comments
 (0)