Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions tests/test_memblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,45 @@ def test_memblock_added_default_named(self):
self.assertIs(pyrtl.working_block().get_memblock_by_name(mem.name), mem)


class RTLMemBlockErrorTests(unittest.TestCase):
def setUp(self):
pyrtl.reset_working_block()

def test_negative_bitwidth(self):
with self.assertRaises(pyrtl.PyrtlError):
pyrtl.MemBlock(-1, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines like these are a little easier to read with keyword arguments (pyrtl.MemBlock(bitwidth=-1, addrwidth=1)) as the reader may not remember whether bitwidth or addrwidth comes first.

If you agree and want to change this, please change it consistently throughout this commit


def test_negative_addrwidth(self):
with self.assertRaises(pyrtl.PyrtlError):
pyrtl.MemBlock(1, -1)

def test_memindex_bitwidth_more_than_addrwidth(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call this greater_than instead of more_than ?

Regardless, the word choice here should be consistent with the next test (more vs larger)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the name to test_memindex_bitwidth_greater_than_addrwidth, but I left test_memblock_write_data_larger_than_memory_bitwidth as it is because I think when we talk about numbers it is better to say greater and less and when we talk about data it is better to say larger and smaller. But I can change it if you disagree.

mem = pyrtl.MemBlock(1, 1)
mem_in = pyrtl.Input(2, 'mem_in')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name mem_in is confusing in this commit as it's sometimes the address, and sometimes the data to write. Seems clearer to always call the address mem_addr?

mem_out = pyrtl.Output(1, 'mem_out')
with self.assertRaises(pyrtl.PyrtlError):
mem_out <<= mem[mem_in]

def test_memblock_write_data_larger_than_memory_bidwidth(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: bidwidth :)

mem = pyrtl.MemBlock(1, 1)
mem_addr = pyrtl.Input(1, 'mem_addr')
mem_in = pyrtl.Input(2, 'mem_in')
with self.assertRaises(pyrtl.PyrtlError):
mem[mem_addr] <<= mem_in

def test_memblock_enable_signal_not_1_bit(self):
mem = pyrtl.MemBlock(1, 1)
mem_addr = pyrtl.Input(1, 'mem_addr')
mem_in = pyrtl.Input(1, 'mem_in')
with self.assertRaises(pyrtl.PyrtlError):
mem[mem_addr] <<= pyrtl.MemBlock.EnabledWrite(mem_in, enable=pyrtl.Input(2))

def test_read_ports_exception(self):
mem = pyrtl.MemBlock(1, 1)
with self.assertRaises(pyrtl.PyrtlError):
mem.read_ports()


class MemIndexedTests(unittest.TestCase):
def setUp(self):
pyrtl.reset_working_block()
Expand Down Expand Up @@ -229,6 +268,53 @@ def test_write_memindexed_ior(self):
self.assertEqual(self.mem1.num_read_ports, 1)
self.assertEqual(self.mem2.num_write_ports, 1)

def test_memindexed_len(self):
self.mem = pyrtl.MemBlock(8, 1)
self.assertEqual(len(self.mem[0]), 8)
self.mem_2 = pyrtl.MemBlock(16, 1)
self.assertEqual(len(self.mem_2[0]), 16)

def test_memindexed_getitem(self):
mem = pyrtl.MemBlock(bitwidth=8, addrwidth=1, max_read_ports=None)
mem_in = pyrtl.Input(1, 'mem_in')
mem_out_array = [pyrtl.Output(8, 'mem_out_' + str(i)) for i in range(8)]
for i in range(8):
mem_out_array[i] <<= mem[mem_in][i]
mem_value_map = {mem: {0: 7, 1: 5}}
sim_trace = pyrtl.SimulationTrace()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a SimulationTrace shouldn't be needed here since Simulation should create a SimulationTrace for you if none is provided (happens again in the next two tests)

sim = pyrtl.Simulation(tracer=sim_trace, memory_value_map=mem_value_map)
for i in range(len(mem_value_map[mem])):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem like we gain much by testing both addresses? Consider just checking the first address as loops like these make it harder to debug the test. When a test failure occurs, we won't know the values of i or j

sim.step({mem_in: i})
binary = bin(mem_value_map[mem][i])[2:].zfill(8)
for j in range(8):
self.assertEqual(sim.inspect(mem_out_array[j]), int(binary[7 - j]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, it's a little better to reassemble the inspected binary value, and then do one big assertEqual, rather than doing one assertEqual for each bit in the loop. With one big assertEqual, the test failure error message will be better, since we can more easily see which bits are wrong. With the single-bit assertEqual, the error message will just tell us that 0 is not equal to 1 or vice versa.

You can reassemble the inspected values with something like:

>>> def reassemble_bits(bits: list[int]):
...     value = 0
...     for bit in bits:
...         value = (value << 1) | bit
...     return value
>>> reassemble_bits([0, 0])
0
>>> reassemble_bits([0, 1])
1
>>> reassemble_bits([1, 0])
2
>>> reassemble_bits([1, 1])
3

If you want to keep things the way they are, consider moving this logic to check if the Nth bit is set to a helper function, as it's a bit complicated and it's repeated in the next two tests.

Also you can check this a little more easily with bitmasks and bitwise AND, something like

>>> def check_nth_bit(value: int, n: int):
...     return bool(value & (1 << n))
>>> for n in range(8):
...     print(n, check_nth_bit(27, n))
...
0 True
1 True
2 False
3 True
4 True
5 False
6 False
7 False


def test_memindexed_sign_extended(self):
mem = pyrtl.MemBlock(bitwidth=8, addrwidth=1)
mem_in = pyrtl.Input(1, 'mem_in')
mem_out = pyrtl.Output(16, 'mem_out')
mem_out <<= mem[mem_in].sign_extended(16)
mem_value_map = {mem: {0: 0b00101101, 1: 0b10011011}}
mem_value_map_sign_extended = [0b0000000000101101, 0b1111111110011011]
sim_trace = pyrtl.SimulationTrace()
sim = pyrtl.Simulation(tracer=sim_trace, memory_value_map=mem_value_map)
for i in range(len(mem_value_map[mem])):
sim.step({mem_in: i})
self.assertEqual(sim.inspect(mem_out), mem_value_map_sign_extended[i])

def test_memindexed_zero_extended(self):
mem = pyrtl.MemBlock(bitwidth=8, addrwidth=1)
mem_in = pyrtl.Input(1, 'mem_in')
mem_out = pyrtl.Output(16, 'mem_out')
mem_out <<= mem[mem_in].zero_extended(16)
mem_value_map = {mem: {0: 0b00101101, 1: 0b10011011}}
mem_value_map_zero_extended = [0b0000000000101101, 0b0000000010011011]
sim_trace = pyrtl.SimulationTrace()
sim = pyrtl.Simulation(tracer=sim_trace, memory_value_map=mem_value_map)
for i in range(len(mem_value_map[mem])):
sim.step({mem_in: i})
self.assertEqual(sim.inspect(mem_out), mem_value_map_zero_extended[i])


class RTLRomBlockWiring(unittest.TestCase):
data = list(range(2**5))
Expand Down
Loading