Skip to content

Commit 1506d18

Browse files
authored
Merge pull request #41 from DzmingLi/main
Finish RISC-V TODO backlog: switch lowering, result returns, optimizer tweaks
2 parents fc13b3b + 9d3d622 commit 1506d18

20 files changed

+740
-112
lines changed

src/ast_derive.ml

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,10 +1991,22 @@ module DeriveJson = struct
19911991
match (default_expr_positional, default_expr_named) with
19921992
| Some expr, None -> Some expr
19931993
| None, Some expr -> Some expr
1994-
| Some _, Some _ ->
1995-
failwith
1996-
"TODO: default_expr_positional and default_expr_named \
1997-
should not both be set"
1994+
| Some expr_pos, Some expr_named ->
1995+
(* When both positional and named layouts specify default expressions,
1996+
check if they are the same. If they differ, this is a semantic error
1997+
indicating conflicting default values for the same parameter. *)
1998+
if Stdlib.( = ) expr_pos expr_named then
1999+
(* Both specify the same default - use it *)
2000+
Some expr_pos
2001+
else
2002+
(* Conflicting defaults - this should not happen in well-formed code *)
2003+
failwith
2004+
(Printf.sprintf
2005+
"Conflicting default expressions for parameter %s: \
2006+
positional and named layouts specify different defaults"
2007+
(match name with
2008+
| None -> "at index " ^ Int.to_string i
2009+
| Some n -> n))
19982010
| None, None -> None
19992011
in
20002012
let default_value =

src/riscv_generate.ml

Lines changed: 207 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,28 @@ let rec do_convert tac (expr: Mcore.expr) =
11681168
| Cexpr_return { expr; return_kind } ->
11691169
(match return_kind with
11701170
| Error_result { is_error; return_ty } ->
1171-
failwith "TODO: riscv_generate.ml: return error"
1171+
(* Wrap the return value in a Result<T, E> constructor *)
1172+
(* tag index: 0 for Err, 1 for Ok *)
1173+
let tag_index = if is_error then 0 else 1 in
1174+
let value = do_convert tac expr in
1175+
let value_size = sizeof value.ty in
1176+
1177+
(* Allocate memory for the Result constructor (tag + value) *)
1178+
let size = 4 + value_size in
1179+
let rd = new_temp Mtype.T_bytes in
1180+
Vec.push tac (Malloc { rd; size });
1181+
1182+
(* Store the tag (0 for Err, 1 for Ok) *)
1183+
let tag = new_temp Mtype.T_int in
1184+
Vec.push tac (AssignInt { rd = tag; imm = tag_index });
1185+
Vec.push tac (Store { rd = tag; rs = rd; offset = 0; byte = 4 });
1186+
1187+
(* Store the value *)
1188+
Vec.push tac (Store { rd = value; rs = rd; offset = 4; byte = value_size });
1189+
1190+
(* Return the constructed Result value *)
1191+
Vec.push tac (Return rd);
1192+
unit
11721193

11731194
| Single_value ->
11741195
let return = do_convert tac expr in
@@ -1295,22 +1316,191 @@ let rec do_convert tac (expr: Mcore.expr) =
12951316
| Cexpr_switch_constant { obj; cases; default; ty; _ } ->
12961317
let index = do_convert tac obj in
12971318
let len = List.length cases in
1298-
1319+
12991320
if len = 0 then (
13001321
(* Only default case is present. No match needs to be generated. *)
13011322
do_convert tac default
13021323
) else (
1303-
let rd = new_temp ty in
1304-
1305-
let values =
1306-
List.map (fun (t, _) ->
1307-
match t with
1308-
| Constant.C_bool b -> Bool.to_int b
1309-
| Constant.C_int { v } -> Int32.to_int v
1310-
| Constant.C_char v -> Uchar.to_int v
1311-
| _ -> failwith "TODO: unsupported switch constant type"
1312-
) cases
1313-
in
1324+
let rd = new_temp ty in
1325+
1326+
(* Check if we have string or bytes constants - they need special handling *)
1327+
let has_string_cases = List.exists (fun (t, _) ->
1328+
match t with
1329+
| Constant.C_string _ -> true
1330+
| _ -> false
1331+
) cases in
1332+
1333+
let has_bytes_cases = List.exists (fun (t, _) ->
1334+
match t with
1335+
| Constant.C_bytes _ -> true
1336+
| _ -> false
1337+
) cases in
1338+
1339+
if has_string_cases then (
1340+
(* String switch: generate if-else chain with string comparison *)
1341+
let ifexit = new_label "str_match_exit_" in
1342+
List.iter (fun (const, expr) ->
1343+
match const with
1344+
| Constant.C_string str_val ->
1345+
let ifso = new_label "str_match_so_" in
1346+
let ifnot = new_label "str_match_not_" in
1347+
let equal = new_temp Mtype.T_bool in
1348+
1349+
(* Create a constant for this string *)
1350+
let str_const = new_temp Mtype.T_bytes in
1351+
let label = Printf.sprintf "str_%d" !slot in
1352+
let vals = String.to_seq str_val |> List.of_seq in
1353+
let len_str = String.length str_val |> Int.to_string in
1354+
let vec = Vec.empty () in
1355+
List.iter (fun x ->
1356+
Vec.push vec (Char.code x);
1357+
Vec.push vec 0) vals;
1358+
let values = len_str :: Vec.map_into_list vec ~unorder:Int.to_string in
1359+
slot := !slot + 1;
1360+
Vec.push global_inst (ExtArray { label; values; elem_size = 1 });
1361+
1362+
(* Load the string constant *)
1363+
let beginning = new_temp Mtype.T_bytes in
1364+
Vec.push tac (AssignLabel { rd = beginning; imm = label; });
1365+
Vec.push tac (Addi { rd = str_const; rs = beginning; imm = 4 });
1366+
1367+
(* Compare strings byte-by-byte (UTF-16 LE format)
1368+
Strings are stored as: [length (4 bytes)] [char1_lo, char1_hi, char2_lo, char2_hi, ...]
1369+
We need to compare length + content *)
1370+
1371+
(* Load lengths from both strings (at offset -4 from data pointer) *)
1372+
let len1 = new_temp Mtype.T_int in
1373+
let len2 = new_temp Mtype.T_int in
1374+
Vec.push tac (Load { rd = len1; rs = index; offset = -4; byte = 4 });
1375+
Vec.push tac (Load { rd = len2; rs = str_const; offset = -4; byte = 4 });
1376+
1377+
(* First check if lengths are equal *)
1378+
let len_eq = new_temp Mtype.T_bool in
1379+
Vec.push tac (Eq { rd = len_eq; rs1 = len1; rs2 = len2 });
1380+
1381+
let check_content = new_label "str_check_content_" in
1382+
Vec.push tac (Branch { cond = len_eq; ifso = check_content; ifnot });
1383+
1384+
(* If lengths equal, use memcmp to compare content *)
1385+
Vec.push tac (Label check_content);
1386+
(* Calculate byte count: length * 2 (each char is 2 bytes in UTF-16) *)
1387+
let two = new_temp Mtype.T_int in
1388+
let byte_count = new_temp Mtype.T_int in
1389+
Vec.push tac (AssignInt { rd = two; imm = 2 });
1390+
Vec.push tac (Mul { rd = byte_count; rs1 = len1; rs2 = two });
1391+
1392+
let cmp_res = new_temp Mtype.T_int in
1393+
let zero = new_temp Mtype.T_int in
1394+
Vec.push tac (CallExtern { rd = cmp_res; fn = "memcmp"; args = [index; str_const; byte_count] });
1395+
Vec.push tac (AssignInt { rd = zero; imm = 0 });
1396+
Vec.push tac (Eq { rd = equal; rs1 = cmp_res; rs2 = zero });
1397+
Vec.push tac (Branch { cond = equal; ifso; ifnot });
1398+
1399+
(* Generate the match case *)
1400+
Vec.push tac (Label ifso);
1401+
let ret = do_convert tac expr in
1402+
Vec.push tac (Assign { rd; rs = ret });
1403+
Vec.push tac (Jump ifexit);
1404+
1405+
Vec.push tac (Label ifnot);
1406+
()
1407+
| _ -> failwith "Mixed string and non-string constants in switch not supported"
1408+
) cases;
1409+
1410+
(* The last ifnot corresponds to the default case *)
1411+
let ret = do_convert tac default in
1412+
Vec.push tac (Assign { rd; rs = ret });
1413+
Vec.push tac (Jump ifexit);
1414+
1415+
Vec.push tac (Label ifexit);
1416+
rd
1417+
) else if has_bytes_cases then (
1418+
(* Bytes switch: generate if-else chain with bytes comparison *)
1419+
let ifexit = new_label "bytes_match_exit_" in
1420+
List.iter (fun (const, expr) ->
1421+
match const with
1422+
| Constant.C_bytes { v; _ } ->
1423+
let ifso = new_label "bytes_match_so_" in
1424+
let ifnot = new_label "bytes_match_not_" in
1425+
let equal = new_temp Mtype.T_bool in
1426+
1427+
(* Create a constant for this bytes value *)
1428+
let bytes_const = new_temp Mtype.T_bytes in
1429+
let label = Printf.sprintf "bytes_%d" !slot in
1430+
let vals = String.to_seq v |> List.of_seq |> List.map (fun x -> Char.code x |> Int.to_string) in
1431+
let len_str = String.length v |> Int.to_string in
1432+
let values = len_str :: vals in
1433+
slot := !slot + 1;
1434+
Vec.push global_inst (ExtArray { label; values; elem_size = 1 });
1435+
1436+
(* Load the bytes constant *)
1437+
let beginning = new_temp Mtype.T_bytes in
1438+
Vec.push tac (AssignLabel { rd = beginning; imm = label; });
1439+
Vec.push tac (Addi { rd = bytes_const; rs = beginning; imm = 4 });
1440+
1441+
(* Compare bytes: bytes are stored as [length (4 bytes)] [byte1, byte2, ...]
1442+
Unlike strings, each byte is 1 byte (not UTF-16) *)
1443+
1444+
(* Load lengths from both bytes values *)
1445+
let len1 = new_temp Mtype.T_int in
1446+
let len2 = new_temp Mtype.T_int in
1447+
Vec.push tac (Load { rd = len1; rs = index; offset = -4; byte = 4 });
1448+
Vec.push tac (Load { rd = len2; rs = bytes_const; offset = -4; byte = 4 });
1449+
1450+
(* First check if lengths are equal *)
1451+
let len_eq = new_temp Mtype.T_bool in
1452+
Vec.push tac (Eq { rd = len_eq; rs1 = len1; rs2 = len2 });
1453+
1454+
let check_content = new_label "bytes_check_content_" in
1455+
Vec.push tac (Branch { cond = len_eq; ifso = check_content; ifnot });
1456+
1457+
(* If lengths equal, use memcmp to compare content *)
1458+
Vec.push tac (Label check_content);
1459+
(* For bytes, byte count = length (each byte is 1 byte) *)
1460+
let cmp_res = new_temp Mtype.T_int in
1461+
let zero = new_temp Mtype.T_int in
1462+
Vec.push tac (CallExtern { rd = cmp_res; fn = "memcmp"; args = [index; bytes_const; len1] });
1463+
Vec.push tac (AssignInt { rd = zero; imm = 0 });
1464+
Vec.push tac (Eq { rd = equal; rs1 = cmp_res; rs2 = zero });
1465+
Vec.push tac (Branch { cond = equal; ifso; ifnot });
1466+
1467+
(* Generate the match case *)
1468+
Vec.push tac (Label ifso);
1469+
let ret = do_convert tac expr in
1470+
Vec.push tac (Assign { rd; rs = ret });
1471+
Vec.push tac (Jump ifexit);
1472+
1473+
Vec.push tac (Label ifnot);
1474+
()
1475+
| _ -> failwith "Mixed bytes and non-bytes constants in switch not supported"
1476+
) cases;
1477+
1478+
(* The last ifnot corresponds to the default case *)
1479+
let ret = do_convert tac default in
1480+
Vec.push tac (Assign { rd; rs = ret });
1481+
Vec.push tac (Jump ifexit);
1482+
1483+
Vec.push tac (Label ifexit);
1484+
rd
1485+
) else (
1486+
(* Non-string/bytes switch: use existing integer-based logic *)
1487+
let values =
1488+
List.map (fun (t, _) ->
1489+
match t with
1490+
| Constant.C_bool b -> Bool.to_int b
1491+
| Constant.C_int { v } -> Int32.to_int v
1492+
| Constant.C_char v -> Uchar.to_int v
1493+
| Constant.C_byte { v; _ } -> v
1494+
| Constant.C_int64 { v; _ } -> Int64.to_int v
1495+
| Constant.C_uint { v; _ } -> Int32.to_int (Basic_uint32.to_int32 v)
1496+
| Constant.C_uint64 { v; _ } -> Int64.to_int (Basic_uint64.to_int64 v)
1497+
| Constant.C_string _ -> failwith "Internal error: string constant in non-string switch"
1498+
| Constant.C_bytes _ -> failwith "Internal error: bytes constant in non-bytes switch"
1499+
| Constant.C_float _ -> failwith "TODO: switch on float constants is not supported"
1500+
| Constant.C_double _ -> failwith "TODO: switch on double constants is not supported"
1501+
| Constant.C_bigint _ -> failwith "TODO: switch on bigint constants is not supported"
1502+
) cases
1503+
in
13141504

13151505
let mx = List.fold_left (fun mx x -> max mx x) (-2147483647-1) values in
13161506
let mn = List.fold_left (fun mn x -> min mn x) 2147483647 values in
@@ -1436,8 +1626,9 @@ let rec do_convert tac (expr: Mcore.expr) =
14361626

14371627
Vec.push tac (JumpIndirect { rs = target; possibilities });
14381628
Vec.append tac tac_cases;);
1439-
1440-
rd
1629+
1630+
rd
1631+
)
14411632
)
14421633

14431634
| Cexpr_handle_error _ ->
@@ -1473,7 +1664,7 @@ let rec do_convert tac (expr: Mcore.expr) =
14731664
| Cexpr_const { c; ty; _ } ->
14741665
let rd = new_temp ty in
14751666
(match c with
1476-
(* Note each element of string is 2 bytes long. TODO *)
1667+
(* Note: Each element of string is 2 bytes long (character code + null byte) *)
14771668
| C_string v ->
14781669
let label = Printf.sprintf "str_%d" !slot in
14791670
let vals = String.to_seq v |> List.of_seq in

src/riscv_opt_peephole.ml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,16 @@ let remove_dead_variable fn =
168168
let preserve = ref true in
169169
reg_iterd (fun x -> preserve := Stringset.mem x.name preserved) x;
170170

171-
(* TODO: refine this, so that calls to pure functions are also eliminated *)
171+
(* Pure function calls can be eliminated if their results are unused.
172+
Non-pure functions must always be preserved due to side effects. *)
172173
match x with
173-
| Call { fn } when not (is_pure fn) -> true
174+
| Call { fn } ->
175+
if is_pure fn then
176+
(* Pure function: eliminate if result is not used *)
177+
!preserve
178+
else
179+
(* Non-pure function: always keep (has side effects) *)
180+
true
174181
| CallExtern _ | CallIndirect _ -> true
175182
| _ -> !preserve
176183
) body;

src/riscv_reg_alloc.ml

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,26 @@ let find_max_freq (freq_map_opt : int SlotMap.t option) : Slot.t option =
116116
in
117117
Some max_reg)
118118

119+
(* [choose_least_used_reg available usage_map] picks, among the registers in
   [available], one whose count in [usage_map] is smallest. Registers without an
   entry are looked up with [find_default ... 0] (presumably "count 0" - confirm
   the argument order of [SlotMap.find_default]). *)
120+
let choose_least_used_reg (available : SlotSet.t) (usage_map : int SlotMap.t) : Slot.t =
121+
(* Fold over [available], carrying the best register seen so far and its count.
   The comparison is strict, so a register with an equal count never replaces
   the current best: ties go to whichever register the fold visits first. *)
122+
let min_reg, _ =
123+
SlotSet.fold available (None, max_int) (fun reg (best_reg, min_count) ->
124+
let count = SlotMap.find_default usage_map reg 0 in
125+
if count < min_count then
126+
(Some reg, count)
127+
else
128+
(best_reg, min_count)
129+
)
130+
in
131+
match min_reg with
132+
| Some r -> r
133+
| None -> SlotSet.choose available (* NOTE(review): only reached when the fold selected nothing, e.g. [available] is empty; the result then depends on how [SlotSet.choose] treats an empty set - confirm callers never pass one *)
134+
119135
(* 1. Allocate register for the entry part.
120136
For each variable, simply use the most frequent register from predecessors.
121137
Since a block may not be the beginning of the loop back edge, for loop back edge predecessors, force them to use the same register.
122138
*)
123-
(* TODO: Optimize choosing strategy *)
124139
let alloc_entry (bl : VBlockLabel.t) =
125140
let binfo = Spill.get_spillinfo bl in
126141
let rinfo = get_allocinfo bl in
@@ -152,13 +167,21 @@ let alloc_entry (bl : VBlockLabel.t) =
152167
| Some reg -> reg_map := SlotMap.add !reg_map var reg;
153168
| None -> unalloc := SlotSet.add !unalloc var;
154169
);
155-
(* Allocate them *)
170+
(* Allocate them using least-used register strategy *)
156171
let reg_used = SlotMap.fold !reg_map SlotSet.empty (fun _ reg used -> SlotSet.add used reg) in
157172
let reg_left = SlotSet.diff available_regs reg_used in
173+
(* Build usage map: count how many times each register is already assigned *)
174+
let usage_count = ref SlotMap.empty in
175+
SlotMap.iter !reg_map (fun _ reg ->
176+
usage_count := SlotMap.add !usage_count reg (SlotMap.find_default !usage_count reg 0 + 1)
177+
);
158178
let _ = SlotSet.fold !unalloc reg_left
159179
(fun var reg_left ->
160-
let reg = SlotSet.choose reg_left in
180+
(* Choose the least-used register from available ones *)
181+
let reg = choose_least_used_reg reg_left !usage_count in
161182
reg_map := SlotMap.add !reg_map var reg;
183+
(* Update usage count *)
184+
usage_count := SlotMap.add !usage_count reg (SlotMap.find_default !usage_count reg 0 + 1);
162185
SlotSet.remove reg_left reg
163186
) in
164187

0 commit comments

Comments
 (0)