Skip to content

Commit 7b357f1

Browse files
authored
Merge pull request #371 from pllab/matrix_helper
Move and improve list_to_int helper from test_matrix.py to matrix.py file
2 parents 164adc6 + 0b092a1 commit 7b357f1

File tree

2 files changed

+76
-10
lines changed

2 files changed

+76
-10
lines changed

pyrtl/rtllib/matrix.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ..wire import Const, WireVector
77
from ..corecircuits import as_wires, concat, select
88
from ..pyrtlexceptions import PyrtlError
9+
from ..helperfuncs import formatted_str_to_val
910

1011

1112
class Matrix(object):
@@ -1342,3 +1343,58 @@ def matrix_wv_to_list(matrix_wv, rows, columns, bits):
13421343
result[i][j] = int_value
13431344
bit_pointer += bits
13441345
return result
1346+
1347+
1348+
def list_to_int(matrix, n_bits):
1349+
''' Convert a Python matrix (a list of lists) into an integer.
1350+
1351+
:param list[list[int]] matrix: a pure Python list of lists representing a matrix
1352+
:param int n_bits: number of bits to be used to represent each element; if an
1353+
element doesn't fit in n_bits, it truncates the most significant bits.
1354+
:return int: a N*n_bits wide wirevector containing the elements of `matrix`,
1355+
where N is the number of elements in `matrix`
1356+
1357+
Integers that are signed will automatically be converted to their two's complement form.
1358+
1359+
This function is helpful for turning a pure Python list of lists
1360+
into a integer suitable for creating a Constant wirevector that can
1361+
be passed in to as a Matrix intializer's `value` argument, or for
1362+
passing into a Simulation's step function for a particular input wire.
1363+
1364+
For example, calling Matrix.list_to_int([3, 5], [7, 9], 4) produces 13,689,
1365+
which in binary looks like this::
1366+
1367+
0011 0101 0111 1001
1368+
1369+
Note how the elements of the list of lists were added, 4 bits at a time,
1370+
in row order, such that the element at row 0, column 0 is in the most significant
1371+
4 bits, and the element at row 1, column 1 is in the least significant 4 bits.
1372+
1373+
Here's an example of using it in simulation::
1374+
1375+
a_vals = [[0, 1], [2, 3]]
1376+
b_vals = [[2, 4, 6], [8, 10, 12]]
1377+
1378+
a_in = pyrtl.Input(4 * 4, 'a_in')
1379+
b_in = pyrtl.Input(6 * 4, 'b_in')
1380+
a = Matrix.Matrix(2, 2, 4, value=a_in)
1381+
b = Matrix.Matrix(2, 3, 4, value=b_in)
1382+
...
1383+
1384+
sim = pyrtl.Simulation()
1385+
sim.step({
1386+
'a_in': Matrix.list_to_int(a_vals)
1387+
'b_in': Matrix.list_to_int(b_vals)
1388+
})
1389+
'''
1390+
if n_bits <= 0:
1391+
raise PyrtlError("Number of bits per element must be positive, instead got %d" % n_bits)
1392+
1393+
result = 0
1394+
1395+
for i in range(len(matrix)):
1396+
for j in range(len(matrix[0])):
1397+
val = formatted_str_to_val(str(matrix[i][j]), 's' + str(n_bits))
1398+
result = (result << n_bits) | val
1399+
1400+
return result

tests/rtllib/test_matrix.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_init_random_wirevector(self):
159159
self.init_wirevector(matrix, rows, columns, bits)
160160

161161
def init_wirevector(self, matrix_value, rows, columns, bits):
162-
matrix_input = pyrtl.Const(matrix_to_int(matrix_value, bits), rows * columns * bits)
162+
matrix_input = pyrtl.Const(Matrix.list_to_int(matrix_value, bits), rows * columns * bits)
163163
matrix = Matrix.Matrix(rows, columns, bits, value=matrix_input)
164164

165165
self.assertEqual(rows, matrix.rows)
@@ -2445,19 +2445,29 @@ def test_vstack_on_non_matrices_fails(self):
24452445
_v = Matrix.vstack(w, m)
24462446

24472447

2448-
'''
2449-
These are helpful functions to use in testing
2450-
'''
2448+
class TestHelpers(unittest.TestCase):
2449+
def setUp(self):
2450+
pyrtl.reset_working_block()
24512451

2452+
def test_list_to_int(self):
2453+
self.assertEquals(Matrix.list_to_int([[0]], 1), 0b0)
2454+
self.assertEquals(Matrix.list_to_int([[1, 2]], 2), 0b0110)
2455+
self.assertEquals(Matrix.list_to_int([[1, 2, 3]], 2), 0b011011)
2456+
self.assertEquals(Matrix.list_to_int([[4, 9, 11], [3, 5, 6]], 4),
2457+
0b010010011011001101010110)
24522458

2453-
def matrix_to_int(matrix, n_bits):
2454-
result = ''
2459+
def test_list_to_int_truncates(self):
2460+
self.assertEquals(Matrix.list_to_int([[4, 9, 27]], 3), 0b100001011)
24552461

2456-
for i in range(len(matrix)):
2457-
for j in range(len(matrix[0])):
2458-
result = result + bin(matrix[i][j])[2:].zfill(n_bits)
2462+
def test_list_to_int_negative(self):
2463+
self.assertEquals(Matrix.list_to_int([[-4, -9, 11]], 5), 0b111001011101011)
24592464

2460-
return int(result, 2)
2465+
def test_list_to_int_negative_truncates(self):
2466+
self.assertEquals(Matrix.list_to_int([[-4, -9, 11]], 3), 0b100111011)
2467+
2468+
def test_list_to_int_non_positive_n_bits(self):
2469+
with self.assertRaises(pyrtl.PyrtlError):
2470+
Matrix.list_to_int([[3]], 0)
24612471

24622472

24632473
if __name__ == '__main__':

0 commit comments

Comments
 (0)