Skip to content

Commit c8337c8

Browse files
AdUhTkJmmengzhuo
authored andcommitted
Add match for sparse constants
1 parent 1446acd commit c8337c8

File tree

2 files changed

+149
-117
lines changed

2 files changed

+149
-117
lines changed

src/riscv_generate.ml

Lines changed: 141 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -505,9 +505,20 @@ let deal_with_prim tac rd (prim: Primitive.prim) args =
505505
Vec.push tac (Xor { rd = temp; rs1 = a; rs2 = b });
506506
Vec.push tac (Slti { rd; rs = temp; imm = 1 })
507507

508+
(* Create a null-pointer. *)
509+
| Pnull ->
510+
Vec.push tac (AssignInt64 { rd; imm = 0L })
511+
512+
| Pis_null ->
513+
let zero = new_temp Mtype.T_bytes in
514+
Vec.push tac (AssignInt64 { rd = zero; imm = 0L });
515+
Vec.push tac (Eq { rd; rs1 = List.hd args; rs2 = zero });
516+
508517
| Ppanic ->
509518
Vec.push tac (CallExtern { rd; fn = "abort"; args })
510519

520+
(* ref.as_non_null in WASM is just a copy *)
521+
| Pas_non_null
511522
| Pidentity ->
512523
Vec.push tac (Assign { rd; rs = List.hd args })
513524

@@ -1205,132 +1216,147 @@ let rec do_convert tac (expr: Mcore.expr) =
12051216

12061217
| Cexpr_switch_constant { obj; cases; default; ty; _ } ->
12071218
let index = do_convert tac obj in
1208-
1209-
let die () =
1210-
failwith "riscv_generate.ml: bad match on constants"
1211-
in
1212-
12131219
let len = List.length cases in
12141220

12151221
if len = 0 then (
12161222
(* Only default case is present. No match needs to be generated. *)
12171223
do_convert tac default
12181224
) else (
12191225
let rd = new_temp ty in
1220-
let (const_ty, _) = List.hd cases in
1221-
(match const_ty with
1222-
| Constant.C_int { v; } ->
1223-
(* Every match case here is an int. Extract the values. *)
1224-
let values =
1225-
List.map (fun (t, _) ->
1226-
match t with
1227-
| Constant.C_int { v } -> Int32.to_int v
1228-
| _ -> die()
1229-
) cases
1230-
in
1231-
1232-
let mx = List.fold_left (fun mx x -> max mx x) (-2147483647-1) values in
1233-
let mn = List.fold_left (fun mn x -> min mn x) 2147483647 values in
1234-
1235-
(* Sparse values, generate a hash function *)
1236-
if mx - mn >= 20 then (
1237-
failwith "TODO: large"
1238-
)
1239-
1240-
(* Dense values, just get a jump table *)
1241-
else (
1242-
let table = new_label "jumptable_int_" in
1243-
let jump = new_label "do_jump_int_" in
1244-
let jumps = List.init (mx - mn + 1) (fun _ -> new_label "jumptable_int_") in
1245-
let out = new_label "jumptable_int_out_" in
1246-
let default_lbl = new_label "jumptable_default_" in
1247-
1248-
(* If the value is outside the min/max range, jump to default *)
1249-
let inrange = new_temp Mtype.T_bool in
1250-
let maximum = new_temp Mtype.T_int in
1251-
let minimum = new_temp Mtype.T_int in
1252-
let _1 = new_temp Mtype.T_bool in
1253-
let _2 = new_temp Mtype.T_bool in
1254-
1255-
(* Evaluate (x < max) && (x > min), which is the range where we can use jump table *)
1256-
Vec.push tac (AssignInt { rd = maximum; imm = mx });
1257-
Vec.push tac (AssignInt { rd = minimum; imm = mn });
1258-
Vec.push tac (Leq { rd = _1; rs1 = index; rs2 = maximum });
1259-
Vec.push tac (Geq { rd = _2; rs1 = index; rs2 = minimum });
1260-
Vec.push tac (And { rd = inrange; rs1 = _1; rs2 = _2 });
1261-
Vec.push tac (Branch { cond = inrange; ifso = jump; ifnot = default_lbl });
1262-
1263-
(* Load the address *)
1264-
Vec.push tac (Label jump);
1265-
1266-
let jtable = new_temp Mtype.T_bytes in
1267-
let ptr_sz = new_temp Mtype.T_int in
1268-
let off = new_temp Mtype.T_int in
1269-
let altered = new_temp Mtype.T_bytes in
1270-
let target = new_temp Mtype.T_bytes in
1271-
1272-
Vec.push tac (AssignLabel { rd = jtable; imm = table });
1273-
Vec.push tac (AssignInt { rd = ptr_sz; imm = pointer_size });
1274-
1275-
(* We must also minus the minimum, unlike switch_constr *)
1276-
let min_var = new_temp Mtype.T_int in
1277-
let ind_2 = new_temp Mtype.T_int in
1278-
1279-
Vec.push tac (AssignInt { rd = min_var; imm = mn });
1280-
Vec.push tac (Sub { rd = ind_2; rs1 = index; rs2 = min_var });
1281-
1282-
(* Now find which address to jump to *)
1283-
Vec.push tac (Mul { rd = off; rs1 = ind_2; rs2 = ptr_sz });
1284-
Vec.push tac (Add { rd = altered; rs1 = jtable; rs2 = off });
1285-
Vec.push tac (Load { rd = target; rs = altered; offset = 0; byte = pointer_size });
1286-
1287-
let visited = Vec.empty () in
1288-
let correspondence = Array.make (List.length cases) "_uninit" in
1289-
1290-
(* For each label, generate the code of it *)
1291-
let tac_cases = Vec.empty () in
12921226

1293-
List.iter2 (fun value (_, expr) ->
1294-
let lbl = List.nth jumps (value - mn) in
1227+
let values =
1228+
List.map (fun (t, _) ->
1229+
match t with
1230+
| Constant.C_int { v } -> Int32.to_int v
1231+
| Constant.C_char v -> Uchar.to_int v
1232+
| _ -> failwith "TODO: unsupported switch constant type"
1233+
) cases
1234+
in
12951235

1296-
Vec.push tac_cases (Label lbl);
1297-
let ret = do_convert tac_cases expr in
1298-
Vec.push tac_cases (Assign { rd; rs = ret });
1299-
Vec.push tac_cases (Jump out);
1300-
Vec.push visited value;
1301-
correspondence.(value - mn) <- lbl
1302-
) values cases;
1303-
1304-
(* For each values in the (min, max) range, redirect them into default *)
1305-
let visited = visited |> Vec.to_list in
1306-
1307-
Vec.push tac_cases (Label default_lbl);
1308-
let ret = do_convert tac_cases default in
1309-
Vec.push tac_cases (Assign { rd; rs = ret });
1310-
Vec.push tac_cases (Jump out);
1311-
1312-
List.iter (fun i ->
1313-
if not (List.mem i visited) then (
1314-
correspondence.(i - mn) <- default_lbl
1315-
)
1316-
) (List.init (mx - mn) (fun i -> i + mn));
1317-
1318-
(* Store the correct order of jump table *)
1319-
Vec.push tac_cases (Label out);
1320-
Vec.push global_inst (ExtArray
1321-
{ label = table; values = Array.to_list correspondence; elem_size = 8 });
1322-
1323-
(* Deduplicate possibilities and jump there *)
1324-
let possibilities =
1325-
Array.to_list correspondence |> Stringset.of_list |> Stringset.to_seq |> List.of_seq
1326-
in
1327-
1328-
Vec.push tac (JumpIndirect { rs = target; possibilities });
1329-
Vec.append tac tac_cases;
1236+
let mx = List.fold_left (fun mx x -> max mx x) (-2147483647-1) values in
1237+
let mn = List.fold_left (fun mn x -> min mn x) 2147483647 values in
1238+
1239+
(* Sparse values, convert to if-else *)
1240+
if mx - mn >= 256 then (
1241+
let ifexit = new_label "match_ifexit_" in
1242+
List.iter2 (fun x (_, expr) ->
1243+
let equal = new_temp Mtype.T_bool in
1244+
let v = new_temp Mtype.T_int in
1245+
let ifso = new_label "match_ifso_" in
1246+
let ifnot = new_label "match_ifnot_" in
1247+
1248+
Vec.push tac (AssignInt { rd = v; imm = x });
1249+
Vec.push tac (Eq { rd = equal; rs1 = index; rs2 = v });
1250+
Vec.push tac (Branch { cond = equal; ifso; ifnot });
1251+
1252+
(* Generate the match case *)
1253+
Vec.push tac (Label ifso);
1254+
let ret = do_convert tac expr in
1255+
Vec.push tac (Assign { rd; rs = ret });
1256+
Vec.push tac (Jump ifexit);
1257+
1258+
Vec.push tac (Label ifnot);
1259+
()
1260+
) values cases;
1261+
1262+
(* The last ifnot corresponds to the default case *)
1263+
let ret = do_convert tac default in
1264+
Vec.push tac (Assign { rd; rs = ret });
1265+
Vec.push tac (Jump ifexit);
1266+
1267+
Vec.push tac (Label ifexit)
1268+
)
1269+
1270+
(* Dense values, emit a jump table *)
1271+
else (
1272+
let table = new_label "jumptable_int_" in
1273+
let jump = new_label "do_jump_int_" in
1274+
let jumps = List.init (mx - mn + 1) (fun _ -> new_label "jumptable_int_") in
1275+
let out = new_label "jumptable_int_out_" in
1276+
let default_lbl = new_label "jumptable_default_" in
1277+
1278+
(* If the value is outside the min/max range, jump to default *)
1279+
let inrange = new_temp Mtype.T_bool in
1280+
let maximum = new_temp Mtype.T_int in
1281+
let minimum = new_temp Mtype.T_int in
1282+
let _1 = new_temp Mtype.T_bool in
1283+
let _2 = new_temp Mtype.T_bool in
1284+
1285+
(* Evaluate (x < max) && (x > min), which is the range where we can use jump table *)
1286+
Vec.push tac (AssignInt { rd = maximum; imm = mx });
1287+
Vec.push tac (AssignInt { rd = minimum; imm = mn });
1288+
Vec.push tac (Leq { rd = _1; rs1 = index; rs2 = maximum });
1289+
Vec.push tac (Geq { rd = _2; rs1 = index; rs2 = minimum });
1290+
Vec.push tac (And { rd = inrange; rs1 = _1; rs2 = _2 });
1291+
Vec.push tac (Branch { cond = inrange; ifso = jump; ifnot = default_lbl });
1292+
1293+
(* Load the address *)
1294+
Vec.push tac (Label jump);
1295+
1296+
let jtable = new_temp Mtype.T_bytes in
1297+
let ptr_sz = new_temp Mtype.T_int in
1298+
let off = new_temp Mtype.T_int in
1299+
let altered = new_temp Mtype.T_bytes in
1300+
let target = new_temp Mtype.T_bytes in
1301+
1302+
Vec.push tac (AssignLabel { rd = jtable; imm = table });
1303+
Vec.push tac (AssignInt { rd = ptr_sz; imm = pointer_size });
1304+
1305+
(* We must also minus the minimum, unlike switch_constr *)
1306+
let min_var = new_temp Mtype.T_int in
1307+
let ind_2 = new_temp Mtype.T_int in
1308+
1309+
Vec.push tac (AssignInt { rd = min_var; imm = mn });
1310+
Vec.push tac (Sub { rd = ind_2; rs1 = index; rs2 = min_var });
13301311

1312+
(* Now find which address to jump to *)
1313+
Vec.push tac (Mul { rd = off; rs1 = ind_2; rs2 = ptr_sz });
1314+
Vec.push tac (Add { rd = altered; rs1 = jtable; rs2 = off });
1315+
Vec.push tac (Load { rd = target; rs = altered; offset = 0; byte = pointer_size });
1316+
1317+
let visited = Vec.empty () in
1318+
let correspondence = Array.make (mx - mn + 1) "_uninit" in
1319+
1320+
(* For each label, generate the code of it *)
1321+
let tac_cases = Vec.empty () in
1322+
1323+
List.iter2 (fun value (_, expr) ->
1324+
let lbl = List.nth jumps (value - mn) in
1325+
1326+
Vec.push tac_cases (Label lbl);
1327+
let ret = do_convert tac_cases expr in
1328+
Vec.push tac_cases (Assign { rd; rs = ret });
1329+
Vec.push tac_cases (Jump out);
1330+
Vec.push visited value;
1331+
correspondence.(value - mn) <- lbl
1332+
) values cases;
1333+
1334+
(* For each values in the (min, max) range, redirect them into default *)
1335+
let visited = visited |> Vec.to_list in
1336+
1337+
Vec.push tac_cases (Label default_lbl);
1338+
let ret = do_convert tac_cases default in
1339+
Vec.push tac_cases (Assign { rd; rs = ret });
1340+
Vec.push tac_cases (Jump out);
1341+
1342+
List.iter (fun i ->
1343+
if not (List.mem i visited) then (
1344+
correspondence.(i - mn) <- default_lbl
13311345
)
1332-
1333-
| _ -> failwith "TODO: unsupported switch constant type");
1346+
) (List.init (mx - mn) (fun i -> i + mn));
1347+
1348+
(* Store the correct order of jump table *)
1349+
Vec.push tac_cases (Label out);
1350+
Vec.push global_inst (ExtArray
1351+
{ label = table; values = Array.to_list correspondence; elem_size = 8 });
1352+
1353+
(* Deduplicate possibilities and jump there *)
1354+
let possibilities =
1355+
Array.to_list correspondence |> Stringset.of_list |> Stringset.to_seq |> List.of_seq
1356+
in
1357+
1358+
Vec.push tac (JumpIndirect { rs = target; possibilities });
1359+
Vec.append tac tac_cases;);
13341360

13351361
rd
13361362
)

src/riscv_ssa.ml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,15 @@ let rec sizeof ty =
250250
| Mtype.T_constr id -> pointer_size
251251
| Mtype.T_fixedarray _ -> pointer_size
252252
| Mtype.T_trait _ -> pointer_size
253-
(* | Mtype.T_optimized_option { elem } -> pointer_size *)
253+
254+
(* Optimized option uses special values to indicate None for integer types *)
255+
(* So the size is equal to the underlying integer type *)
256+
| Mtype.T_optimized_option { elem } -> sizeof elem
257+
258+
(* Same size as the underlying type *)
259+
| Mtype.T_maybe_uninit x -> sizeof x
260+
254261
(* | Mtype.T_any { name } -> pointer_size *)
255-
(* | Mtype.T_maybe_uninit x -> sizeof x *)(*Same size as the contained type *)
256262
(* | Mtype.T_error_value_result { ok; err; id } -> sizeof ok + sizeof err + pointer_size *)
257263
| _ -> failwith ("riscv_ssa.ml: cannot calculate size for type: "^ Mtype.to_string ty)
258264
;;

0 commit comments

Comments
 (0)