Skip to content

Commit 756e381

Browse files
AdUhTkJmmengzhuo
authored andcommitted
Bugfix for register spilling
1 parent 1c3d00f commit 756e381

File tree

8 files changed

+64
-77
lines changed

8 files changed

+64
-77
lines changed

src/label.ml

Lines changed: 22 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -19,66 +19,28 @@ module Label = struct
1919
}
2020

2121
include struct
22-
let _ = fun (_ : t) -> ()
23-
24-
let sexp_of_t =
25-
(fun { name = name__002_; stamp = stamp__004_ } ->
26-
let bnds__001_ = ([] : _ Stdlib.List.t) in
27-
let bnds__001_ =
28-
let arg__005_ = Moon_sexp_conv.sexp_of_int stamp__004_ in
29-
(S.List [ S.Atom "stamp"; arg__005_ ] :: bnds__001_ : _ Stdlib.List.t)
30-
in
31-
let bnds__001_ =
32-
let arg__003_ = Moon_sexp_conv.sexp_of_string name__002_ in
33-
(S.List [ S.Atom "name"; arg__003_ ] :: bnds__001_ : _ Stdlib.List.t)
34-
in
35-
S.List bnds__001_
36-
: t -> S.t)
37-
;;
38-
39-
let _ = sexp_of_t
40-
41-
let equal =
42-
(fun a__006_ b__007_ ->
43-
if Stdlib.( == ) a__006_ b__007_
44-
then true
45-
else Stdlib.( = ) (a__006_.stamp : int) b__007_.stamp
46-
: t -> t -> bool)
47-
;;
48-
49-
let _ = equal
50-
51-
let (hash_fold_t : Ppx_base.state -> t -> Ppx_base.state) =
52-
fun hsv arg ->
53-
let hsv =
54-
let hsv = hsv in
55-
hsv
56-
in
57-
Ppx_base.hash_fold_int hsv arg.stamp
58-
;;
59-
60-
let _ = hash_fold_t
61-
62-
let (hash : t -> Ppx_base.hash_value) =
63-
let func arg =
64-
Ppx_base.get_hash_value
65-
(let hsv = Ppx_base.create () in
66-
hash_fold_t hsv arg)
67-
in
68-
fun x -> func x
69-
;;
70-
71-
let _ = hash
72-
73-
let compare =
74-
(fun a__008_ b__009_ ->
75-
if Stdlib.( == ) a__008_ b__009_
76-
then 0
77-
else Stdlib.compare (a__008_.stamp : int) b__009_.stamp
78-
: t -> t -> int)
79-
;;
80-
81-
let _ = compare
22+
let sexp_of_t { name; stamp } =
23+
S.List [
24+
S.List [ S.Atom "name"; Moon_sexp_conv.sexp_of_string name ];
25+
S.List [ S.Atom "stamp"; Moon_sexp_conv.sexp_of_int stamp ]]
26+
27+
let equal a b =
28+
if a == b then true
29+
else a.name = b.name && a.stamp = b.stamp
30+
31+
let hash_fold_t hsv arg =
32+
let hsv =
33+
Ppx_base.hash_fold_string hsv arg.name
34+
in Ppx_base.hash_fold_int hsv arg.stamp
35+
36+
let hash arg =
37+
Ppx_base.get_hash_value (hash_fold_t (Ppx_base.create ()) arg)
38+
39+
let compare a b =
40+
if a == b then 0
41+
else if a.name <> b.name then
42+
Stdlib.compare a.name b.name
43+
else Stdlib.compare a.stamp b.stamp
8244
end
8345
end
8446

src/riscv_generate.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -868,9 +868,9 @@ let rec do_convert tac (expr: Mcore.expr) =
868868
let old_vars = !loop_vars in
869869

870870
(* Get the labels *)
871-
let loop = Printf.sprintf "%s_%d" label.name label.stamp in
872-
let before = Printf.sprintf "before_%s" loop in
873-
let exit = Printf.sprintf "exit_%s" loop in
871+
let loop = Printf.sprintf "loophead_%s_%d" label.name label.stamp in
872+
let before = Printf.sprintf "loopbefore_%s" loop in
873+
let exit = Printf.sprintf "loopexit_%s" loop in
874874

875875
Vec.push tac (Jump before);
876876

@@ -910,7 +910,7 @@ let rec do_convert tac (expr: Mcore.expr) =
910910
List.iter2 (fun rd rs -> Vec.push tac (Assign { rd; rs })) !loop_vars results;
911911

912912
(* Jump back to the beginning of the loop. *)
913-
let loop_name = Printf.sprintf "%s_%d" label.name label.stamp in
913+
let loop_name = Printf.sprintf "loophead_%s_%d" label.name label.stamp in
914914
Vec.push tac (Jump loop_name);
915915
unit
916916

src/riscv_opt.ml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ let exit_fn = Hashtbl.create 256
3131
let (params: (string, var list) Hashtbl.t) = Hashtbl.create 256
3232

3333
(** Get the basic block with label `name`. *)
34-
let block_of name = Hashtbl.find basic_blocks name
34+
let block_of name =
35+
match Hashtbl.find_opt basic_blocks name with
36+
| None -> failwith (Printf.sprintf "riscv_opt.ml: unknown basic block: %s" name)
37+
| Some x -> x
3538

3639
(** Get the body of a basic block. *)
3740
let body_of name = (block_of name).body |> Vec.to_list

src/riscv_reg.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ module Imm = struct
176176
let to_string imm =
177177
match imm with
178178
| IntImm i -> string_of_int i
179+
| Int64Imm i -> Int64.to_string i
179180
| FloatImm f -> string_of_float f
180181
;;
181182
end

src/riscv_reg_spill.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ let rpo : RPO.t ref = ref RPO.empty
4040

4141
(* Spill environment: used to build the basic environment, etc. *)
4242
module SpillEnv = struct
43+
(**
44+
Data structure for data-flow analysis.
45+
Here W stands for working registers (those aren't spilled);
46+
S stands for spilled registers.
47+
*)
4348
type spill_info =
4449
{ entryW : SlotSet.t
4550
; exitW : SlotSet.t
@@ -145,7 +150,7 @@ let compute_every_inst_nextuse (bl : VBlockLabel.t) : int SlotMap.t Vec.t =
145150
let block = VProg.get_block !vprog bl in
146151
let b_liveinfo = Liveness.get_liveinfo !live_info bl in
147152
let n = Vec.length block.body in
148-
let nextUse = Vec.make ~dummy:SlotMap.empty (n + 1) in
153+
let nextUse = Vec.of_list (List.init (n + 1) (fun _ -> SlotMap.empty)) in
149154

150155
(* 1. Initialization *)
151156
SlotMap.iter b_liveinfo.exitNextUse (fun var dist ->
@@ -439,7 +444,7 @@ let apply_min_algorithm (bl : VBlockLabel.t) (nextUse : int SlotMap.t Vec.t) =
439444

440445
(*3. Reload/Spill to be inserted before each instruction, including before Term, so it is n+1 *)
441446
let body_size = Vec.length block.body in
442-
let addInsts = Vec.make ~dummy:(Vec.empty ()) (body_size + 1) in
447+
let addInsts = Vec.of_list (List.init (body_size + 1) (fun _ -> Vec.empty ())) in
443448

444449
(* 4. Common apply_inner function, but for clarity, the function to adjust k is passed in *)
445450
let apply_inner
@@ -472,7 +477,8 @@ let apply_min_algorithm (bl : VBlockLabel.t) (nextUse : int SlotMap.t Vec.t) =
472477
SlotSet.iter dests (fun var -> w := SlotSet.add !w var);
473478
let protected = dests in
474479
(* At this point, protected protects the registers being defined *)
475-
let _ = limit_func nextUse w s spill protected (i + 1) adjust_k in
480+
if i <> body_size then
481+
limit_func nextUse w s spill protected (i + 1) adjust_k;
476482

477483
(* e. Insert reload/spill instructions *)
478484
SlotSet.iter !reload (fun var -> Vec.push addInsts.![i] (Inst.generate_reload var));
@@ -580,9 +586,10 @@ let spill_reload_func (f_label : VFuncLabel.t) (func : VFunc.t) =
580586
;;
581587

582588
(* Main function: used to handle the entire program *)
583-
let spill_regs (vprog_in : VProg.t) (rpo : RPO.t) =
589+
let spill_regs (vprog_in : VProg.t) (rpo_arg : RPO.t) =
584590
vprog := vprog_in;
585-
live_info := Liveness.liveness_analysis !vprog rpo;
591+
live_info := Liveness.liveness_analysis !vprog rpo_arg;
592+
rpo := rpo_arg;
586593
VFuncMap.iter !vprog.funcs spill_reload_func;
587594
()
588595
;;

src/riscv_reg_util.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ module RPO = struct
5757
let get_func_rpo (funn : VFuncLabel.t) (rpo : t) : VBlockLabel.t list =
5858
match VFuncMap.find_opt rpo funn with
5959
| Some x -> x
60-
| None -> failwith "RPO.get_func_rpo: function not found"
60+
| None -> failwith
61+
(Printf.sprintf "riscv_reg_util.ml: RPO not found for label %s_%d" funn.name funn.stamp)
6162
;;
6263

6364
let empty : t = VFuncMap.empty

src/riscv_virtasm_generate.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,14 @@ let gen_fn (f: fn) =
358358
Hashtbl.add vblocks x {
359359
body; term = !term; preds =
360360
Vec.to_list block.pred
361-
|> List.map (fun x -> VBlock.NormalEdge (label_of x))
361+
|> List.map (fun pred ->
362+
(* The labels are fixed for loops in `riscv_generate.ml`. *)
363+
if String.starts_with ~prefix:"loophead_" x
364+
&& not (String.starts_with ~prefix:"loopbefore_" pred) then
365+
VBlock.LoopBackEdge (label_of x)
366+
else
367+
VBlock.NormalEdge (label_of pred)
368+
)
362369
}
363370
) blocks;
364371
()
@@ -375,13 +382,13 @@ let virtasm_of_ssa (ssa : Riscv_ssa.t list) =
375382
| ExtArray arr -> gen_extarr arr
376383
| _ -> failwith "riscv_virtasm_generate.ml: bad toplevel SSA") ssa;
377384

378-
let funcs = Label.Map.add_list (
385+
let funcs = Label.Map.of_list (
379386
Vec.to_list vfuncs |> List.map (fun (x: VFunc.t) -> (x.funn, x))
380-
) Label.Map.empty in
387+
) in
381388

382-
let blocks = Label.Map.add_list (
389+
let blocks = Label.Map.of_list (
383390
Hashtbl.to_seq vblocks |> List.of_seq |> List.map (fun (k, v) -> (label_of k, v))
384-
) Label.Map.empty in
391+
) in
385392

386393

387394
let out = Printf.sprintf "%s.vasm" !Driver_config.Linkcore_Opt.output_file in

test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ def try_remove(path):
6868

6969
for src in cases:
7070
print(f"Execute task: {src}")
71-
71+
72+
# Remove files from last time
73+
try_remove(f"build/{target}.s.ir")
74+
try_remove(f"build/{target}.s.ssa")
75+
try_remove(f"build/{target}.s-no-opt.ssa")
76+
try_remove(f"build/{target}.s.vasm")
77+
7278
# Note build-package is ignorant of target. It builds to a common IR.
7379
os.system(f"moonc build-package src/{src}/{src}.mbt -is-main -std-path {bundled} -o build/{src}.core")
7480

0 commit comments

Comments
 (0)