Skip to content

Commit 676893c

Browse files
authored
RISC-V: add reg_access and test its usages in C and Python (#2895)
* RISC-V: add reg_access and test its usages in C and Python. reg_access is a convenience wrapper over the `operands` array that filters the register operands (including those used as memory base addresses) and returns them sorted into read and written registers. It wasn't implemented for RISC-V; this PR implements it. The following decisions were made for RISC-V: 1- System registers (CSRs) are not registers. This follows existing Capstone convention, where almost every architecture that has system registers, except x86, treats them as a separate address space. From a purely practical POV, the reg_access function API returns registers as an array of integers, and the address space of normal registers intersects with that of system registers, so there is nothing in the return value to distinguish them. 2- PC is not an implicit register. Whenever an instruction reads the PC (e.g. all call-ish instructions, JAL[R]?), this is NOT counted as an implicit read of the PC. The reason is that the PC is somewhat "second class" in RISC-V: it's an architectural register but has no actual index and can never be directly written to by any instruction in any standard extension, no matter the privilege level. Meanwhile, all instructions that read the PC have names that make it obvious they read the PC, so adding that information to the implicit reads array would be redundant.
1 parent 1084d36 commit 676893c

8 files changed

Lines changed: 294 additions & 2 deletions

File tree

arch/RISCV/RISCVMapping.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,39 @@ riscv_insn RISCV_map_insn(const char *name)
434434
return RISCV_INS_INVALID;
435435
}
436436

437+
/*
 * Collect the registers referenced by the operands of a decoded RISC-V
 * instruction, split into a "read" list and a "written" list.
 *
 *  - A register operand is added to regs_write when its access mask has
 *    CS_AC_WRITE set and to regs_read when it has CS_AC_READ set (an operand
 *    with both bits lands in both lists).
 *  - A memory operand adds its base register to regs_read, unless the base is
 *    RISCV_REG_INVALID.
 *  - Every other operand type is ignored.
 *
 * arr_exist() is consulted before each append so that a register is not
 * stored twice in the same list. The final list lengths are written through
 * regs_read_count and regs_write_count.
 */
void RISCV_reg_access(const cs_insn *insn, cs_regs regs_read,
		      uint8_t *regs_read_count, cs_regs regs_write,
		      uint8_t *regs_write_count)
{
	const cs_riscv *detail = &(insn->detail->riscv);
	uint8_t n_read = 0;
	uint8_t n_write = 0;
	int idx = 0;

	while (idx < detail->op_count) {
		const cs_riscv_op *operand = &detail->operands[idx++];

		switch (operand->type) {
		case RISCV_OP_REG:
			if (operand->access & CS_AC_WRITE) {
				if (!arr_exist(regs_write, n_write,
					       operand->reg)) {
					regs_write[n_write] =
						(uint16_t)operand->reg;
					n_write++;
				}
			}
			if (operand->access & CS_AC_READ) {
				if (!arr_exist(regs_read, n_read,
					       operand->reg)) {
					regs_read[n_read] =
						(uint16_t)operand->reg;
					n_read++;
				}
			}
			break;
		case RISCV_OP_MEM:
			/* The base register is read to form the address. */
			if (operand->mem.base == RISCV_REG_INVALID)
				break;
			if (!arr_exist(regs_read, n_read, operand->mem.base)) {
				regs_read[n_read] = (uint16_t)operand->mem.base;
				n_read++;
			}
			break;
		default:
			break;
		}
	}

	*regs_write_count = n_write;
	*regs_read_count = n_read;
}
469+
437470
void RISCV_init(MCRegisterInfo *MRI)
438471
{
439472
MCRegisterInfo_InitMCRegisterInfo(MRI, RISCVRegDesc, RISCV_REG_ENDING,

arch/RISCV/RISCVMapping.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,8 @@ riscv_insn RISCV_map_insn(const char *name);
3838

3939
void RISCV_init(MCRegisterInfo *MRI);
4040

41+
/*
 * Fill regs_read / regs_write with the registers used by the operands of
 * `insn` (a memory operand contributes its base register as a read) and
 * store the number of entries written to each list in regs_read_count /
 * regs_write_count.
 */
void RISCV_reg_access(const cs_insn *insn, cs_regs regs_read,
42+
uint8_t *regs_read_count, cs_regs regs_write,
43+
uint8_t *regs_write_count);
44+
4145
#endif

arch/RISCV/RISCVModule.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ cs_err RISCV_global_init(cs_struct *ud)
3131
ud->group_name = RISCV_group_name;
3232
ud->insn_map = RISCV_insns;
3333
ud->insn_map_size = RISCV_insn_count;
34+
ud->reg_access = RISCV_reg_access;
3435

3536
return CS_ERR_OK;
3637
}

bindings/python/tests/test_all.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import test_customized_mnem
77
import test_compatibility_layer
88
import test_riscv_sysreg
9+
import test_riscv_reg_access
910

1011
errors = []
1112
errors.extend(test_lite.test_class())
@@ -14,6 +15,7 @@
1415
errors.extend(test_customized_mnem.test())
1516
errors.extend(test_compatibility_layer.test_compatibility())
1617
errors.extend(test_riscv_sysreg.test())
18+
errors.extend(test_riscv_reg_access.test())
1719

1820
if errors:
1921
print("Some errors happened. Please check the output")
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import unittest
2+
from capstone import *
3+
from capstone.riscv import *
4+
import unittest
5+
6+
class TestRiscvRegAccess(unittest.TestCase):
7+
def setUp(self):
8+
self.cs = Cs(CS_ARCH_RISCV, CS_MODE_RISCV64)
9+
self.cs.option(CS_OPT_DETAIL, CS_OPT_DETAIL_REAL | CS_OPT_ON)
10+
11+
def test_addi(self):
12+
# addi a0, a1, 10
13+
code = b"\x13\x85\xa5\x00"
14+
insns = list(self.cs.disasm(code, 0))
15+
self.assertEqual(len(insns), 1)
16+
insn = insns[0]
17+
18+
read, write = insn.regs_access()
19+
# a1 = RISCV_REG_X11, a0 = RISCV_REG_X10
20+
self.assertIn(RISCV_REG_X11, read)
21+
self.assertIn(RISCV_REG_X10, write)
22+
self.assertEqual(len(read), 1)
23+
self.assertEqual(len(write), 1)
24+
25+
def test_jalr(self):
26+
# jalr ra, a1, 0 -> 0x000580e7 (rd=x1=ra, rs1=x11=a1, imm=0)
27+
code = b"\xe7\x80\x05\x00"
28+
insns = list(self.cs.disasm(code, 0))
29+
self.assertEqual(len(insns), 1)
30+
insn = insns[0]
31+
32+
read, write = insn.regs_access()
33+
# ra = RISCV_REG_X1
34+
self.assertIn(RISCV_REG_X11, read)
35+
self.assertIn(RISCV_REG_X1, write)
36+
self.assertEqual(len(read), 1)
37+
self.assertEqual(len(write), 1)
38+
39+
def test_lb(self):
40+
# lb a0, 0(sp)
41+
code = b"\x03\x05\x01\x00"
42+
insns = list(self.cs.disasm(code, 0))
43+
self.assertEqual(len(insns), 1)
44+
insn = insns[0]
45+
46+
read, write = insn.regs_access()
47+
# sp = RISCV_REG_X2
48+
self.assertIn(RISCV_REG_X2, read)
49+
self.assertIn(RISCV_REG_X10, write)
50+
self.assertEqual(len(read), 1)
51+
self.assertEqual(len(write), 1)
52+
53+
def test_caddi(self):
54+
# c.addi a0, 10 (0x0529)
55+
code = b"\x29\x05"
56+
insns = list(self.cs.disasm(code, 0))
57+
self.assertEqual(len(insns), 1)
58+
insn = insns[0]
59+
60+
read, write = insn.regs_access()
61+
# x10 is both read and written
62+
self.assertIn(RISCV_REG_X10, read)
63+
self.assertIn(RISCV_REG_X10, write)
64+
self.assertEqual(len(read), 1)
65+
self.assertEqual(len(write), 1)
66+
67+
def test_ecall(self):
68+
# ecall
69+
code = b"\x73\x00\x00\x00"
70+
insns = list(self.cs.disasm(code, 0))
71+
self.assertEqual(len(insns), 1)
72+
insn = insns[0]
73+
74+
read, write = insn.regs_access()
75+
self.assertEqual(len(read), 0)
76+
self.assertEqual(len(write), 0)
77+
78+
def test_csrrw(self):
79+
# csrrw a0, sstatus, a1
80+
code = b"\x73\x95\x05\x10"
81+
insns = list(self.cs.disasm(code, 0))
82+
self.assertEqual(len(insns), 1)
83+
insn = insns[0]
84+
85+
read, write = insn.regs_access()
86+
# CSRs should NOT be in the reg_access list
87+
self.assertIn(RISCV_REG_X11, read)
88+
self.assertIn(RISCV_REG_X10, write)
89+
self.assertEqual(len(read), 1)
90+
self.assertEqual(len(write), 1)
91+
92+
def test():
    """Run the RISC-V reg_access test case and return every unsuccessful test.

    Returns a list of (test, traceback_text) tuples that combines assertion
    failures with tests that raised an unexpected exception, so a caller that
    only checks whether the list is empty does not mistake a crashed test for
    a pass. The list is empty when every test succeeded.
    """
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestRiscvRegAccess)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    # `failures` only holds assertion failures; unexpected exceptions are
    # recorded separately in `errors` and must be reported as well.
    return result.failures + result.errors
97+
98+
def main():
99+
unittest.main()
100+
101+
if __name__ == '__main__':
102+
main()

docs/cs_v6_release_guide.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,12 @@ Nonetheless, we hope this additional information is useful to you.
249249
Note that `+noalias` "overpowers" `noaliascompressed` in the second case: despite `+noaliascompressed` being false, meaning aliases are wanted for compressed instructions, `+noalias` being true means ALL aliases are suppressed, and this takes precedence. Other than that, case 1 and case 3 work as intuitively expected, and case 4 is redundant.
250250

251251
So a single-sentence description of this table is: if `+noalias` is given then no aliases will be printed for any instruction, but if not given then aliases will be printed for non-compressed instructions, and alias printing for compressed instructions further checks `+noaliascompressed` before proceeding.
252+
- Added `reg_access` capstone callback to return all read and written registers for the instructions, including registers used as part of memory operands.
253+
* Note that `reg_access` does NOT treat CSRs as registers; the detailed reasons can be found in [the PR implementing the feature](https://github.com/capstone-engine/capstone/pull/2895).
254+
* Note that `reg_access` does NOT treat reading the PC's value as reading a register; the detailed reasons can be found in [the PR implementing the feature](https://github.com/capstone-engine/capstone/pull/2895).
255+
252256
> [!NOTE]
253-
> All extensions above are disabled by default unless enabled by their option name or the corresponding command line flag in cstool. Any other extension is always enabled and can't be disabled.
257+
> All `CS_MODE_RISCV_*` extensions above are disabled by default unless enabled by their option name or the corresponding command line flag in cstool. Any other extension is always enabled and can't be disabled.
254258
255259
> [!NOTE]
256260
> RISC-V has a massive, sprawling list of extensions, but Capstone's internal implementation choice of using a 32-bit mode field is not enough to cover all of them. For now, the extension flags above were added because their encoding space conflicts with either each other or other extensions. More flags can be added later if bug reports come in requesting finer-grained extension control. However, the current implementation using bitfields imposes a strict upper limit and would likely be refactored for a more expansive mechanism in the future. See [this issue](https://github.com/capstone-engine/capstone/issues/2848) for more details.

tests/unit/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.15)
33
enable_testing()
44
set(UNIT_TEST_SOURCES sstream.c utils.c)
55
if(CAPSTONE_RISCV_SUPPORT)
6-
list(APPEND UNIT_TEST_SOURCES riscv_op_count_iter.c riscv_sysreg.c)
6+
list(APPEND UNIT_TEST_SOURCES riscv_op_count_iter.c riscv_sysreg.c riscv_reg_access.c)
77
endif()
88
include_directories(include)
99

tests/unit/riscv_reg_access.c

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#include "unit_test.h"
2+
#include <capstone/capstone.h>
3+
#include <stdio.h>
4+
#include <string.h>
5+
6+
/* Return true when `reg` is one of the first `count` entries of `list`. */
static bool reg_list_contains(const uint16_t *list, size_t count, uint16_t reg)
{
	for (size_t k = 0; k < count; k++) {
		if (list[k] == reg)
			return true;
	}
	return false;
}

/*
 * Disassemble one instruction from `code`, query cs_regs_access() on it and
 * compare the reported read/written registers with the expected lists.
 *
 * For each list the reported count must equal the expected count, and every
 * expected register must be present (order is not significant). A mismatch is
 * printed to stdout. Returns true only when both lists match; returns false
 * when disassembly or cs_regs_access() fails.
 */
static bool test_reg_access(csh handle, const uint8_t *code, size_t code_size,
			    const uint16_t *expected_read,
			    size_t expected_read_count,
			    const uint16_t *expected_write,
			    size_t expected_write_count)
{
	cs_insn *decoded;
	size_t n_insn = cs_disasm(handle, code, code_size, 0, 1, &decoded);
	if (n_insn == 0) {
		printf("Failed to disassemble instruction\n");
		return false;
	}
	/* When debugging, printing decoded->mnemonic / decoded->op_str here is
	 * useful (but noisy). */

	cs_regs got_read, got_write;
	uint8_t got_read_count, got_write_count;
	bool ok = false;

	cs_err status = cs_regs_access(handle, decoded, got_read,
				       &got_read_count, got_write,
				       &got_write_count);
	if (status != CS_ERR_OK) {
		printf("cs_regs_access failed with error: %d\n", status);
		goto done;
	}

	ok = true;

	if (got_read_count == expected_read_count) {
		for (size_t i = 0; i < expected_read_count; i++) {
			if (reg_list_contains(got_read, got_read_count,
					      expected_read[i]))
				continue;
			printf("Expected read register %d not found\n",
			       expected_read[i]);
			ok = false;
		}
	} else {
		printf("Read count mismatch: expected %zu, got %u\n",
		       expected_read_count, got_read_count);
		ok = false;
	}

	if (got_write_count == expected_write_count) {
		for (size_t i = 0; i < expected_write_count; i++) {
			if (reg_list_contains(got_write, got_write_count,
					      expected_write[i]))
				continue;
			printf("Expected write register %d not found\n",
			       expected_write[i]);
			ok = false;
		}
	} else {
		printf("Write count mismatch: expected %zu, got %u\n",
		       expected_write_count, got_write_count);
		ok = false;
	}

done:
	cs_free(decoded, n_insn);
	return ok;
}
78+
79+
int main(void)
80+
{
81+
csh handle;
82+
if (cs_open(CS_ARCH_RISCV, CS_MODE_RISCV64, &handle) != CS_ERR_OK) {
83+
return 1;
84+
}
85+
cs_option(handle, CS_OPT_DETAIL, CS_OPT_DETAIL_REAL | CS_OPT_ON);
86+
87+
bool success[10];
88+
memset(success, true, sizeof(success));
89+
90+
// addi a0, a1, 10 -> 0x00a58513
91+
printf("Test 0: Testing addi a0, a1, 10\n");
92+
uint8_t addi_code[] = { 0x13, 0x85, 0xa5, 0x00 };
93+
uint16_t addi_read[] = { RISCV_REG_X11 }; // a1
94+
uint16_t addi_write[] = { RISCV_REG_X10 }; // a0
95+
success[0] = test_reg_access(handle, addi_code, sizeof(addi_code),
96+
addi_read, 1, addi_write, 1);
97+
// jalr ra, a1, 0 -> 0x000580e7 (rd=x1=ra, rs1=x11=a1, imm=0)
98+
printf("Test 1: Testing jalr ra, a1, 0\n");
99+
uint8_t jalr_code[] = { 0xe7, 0x80, 0x05, 0x00 };
100+
uint16_t jalr_read[] = { RISCV_REG_X11 };
101+
uint16_t jalr_write[] = { RISCV_REG_X1 }; // ra
102+
success[1] = test_reg_access(handle, jalr_code, sizeof(jalr_code),
103+
jalr_read, 1, jalr_write, 1);
104+
// lb a0, 0(sp) -> 0x00010503
105+
printf("Test 2: Testing lb a0, 0(sp)\n");
106+
uint8_t lb_code[] = { 0x03, 0x05, 0x01, 0x00 };
107+
uint16_t lb_read[] = { RISCV_REG_X2 }; // sp
108+
uint16_t lb_write[] = { RISCV_REG_X10 };
109+
success[2] = test_reg_access(handle, lb_code, sizeof(lb_code), lb_read,
110+
1, lb_write, 1);
111+
112+
// c.addi a0, 10 -> 0x0529
113+
printf("Test 3: Testing c.addi a0, 10\n");
114+
uint8_t caddi_code[] = { 0x29, 0x05 };
115+
uint16_t caddi_read[] = { RISCV_REG_X10 }; // x10 is both read and write
116+
uint16_t caddi_write[] = { RISCV_REG_X10 };
117+
success[3] = test_reg_access(handle, caddi_code, sizeof(caddi_code),
118+
caddi_read, 1, caddi_write, 1);
119+
120+
// ecall -> 0x00000073
121+
printf("Test 4: Testing ecall\n");
122+
uint8_t ecall_code[] = { 0x73, 0x00, 0x00, 0x00 };
123+
success[4] = test_reg_access(handle, ecall_code, sizeof(ecall_code),
124+
NULL, 0, NULL, 0);
125+
126+
// csrrw a0, sstatus, a1 -> 0x10059533 (Wait, CSRRW is 0x10059573?)
127+
// 0x10059573: csrrw x10, sstatus, x11
128+
printf("Test 5: Testing csrrw a0, sstatus, a1\n");
129+
uint8_t csrrw_code[] = { 0x73, 0x95, 0x05, 0x10 };
130+
uint16_t csrrw_read[] = {
131+
RISCV_REG_X11
132+
}; // sstatus (CSR) should NOT be here
133+
uint16_t csrrw_write[] = { RISCV_REG_X10 };
134+
success[5] = test_reg_access(handle, csrrw_code, sizeof(csrrw_code),
135+
csrrw_read, 1, csrrw_write, 1);
136+
137+
cs_close(&handle);
138+
bool all_success = true;
139+
for (int i = 0; i < sizeof(success) / sizeof(success[0]); i++) {
140+
if (!success[i]) {
141+
printf("Test %d failed\n", i);
142+
all_success = false;
143+
}
144+
}
145+
return all_success ? 0 : 1;
146+
}

0 commit comments

Comments
 (0)