Skip to content

Commit 164adc6

Browse files
authored
Merge pull request #369 from pllab/matrix_update
Fix Matrix, flatten(), add reshape() and put() methods
2 parents 4e02e77 + 7c2fe05 commit 164adc6

File tree

2 files changed

+480
-26
lines changed

2 files changed

+480
-26
lines changed

pyrtl/rtllib/matrix.py

Lines changed: 165 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,23 @@ def __getitem__(self, key):
213213
raise PyrtlError('Rows must be of type int or slice, '
214214
'instead "%s" was passed of type %s' %
215215
(str(rows), type(rows)))
216+
if rows < 0:
217+
rows = self.rows - abs(rows)
218+
if rows < 0:
219+
raise PyrtlError("Invalid bounds for rows. Max rows: %s, got: %s" % (
220+
str(self.rows), str(rows)))
216221
rows = slice(rows, rows + 1, 1)
217222

218223
if not isinstance(columns, slice):
219224
if not isinstance(columns, int):
220225
raise PyrtlError('Columns must be of type int or slice, '
221226
'instead "%s" was passed of type %s' %
222227
(str(columns), type(columns)))
228+
if columns < 0:
229+
columns = self.columns - abs(columns)
230+
if columns < 0:
231+
raise PyrtlError("Invalid bounds for columns. Max columns: %s, got: %s" % (
232+
str(self.columns), str(columns)))
223233
columns = slice(columns, columns + 1, 1)
224234

225235
if rows.start is None:
@@ -665,30 +675,169 @@ def pow_2(first, second):
665675

666676
raise PyrtlError('Power must be greater than or equal to 0')
667677

668-
def flatten(self, order='C'):
669-
''' Flatten the matrix into a single row.
678+
def put(self, ind, v, mode='raise'):
679+
''' Replace specified elements of the matrix with given values
670680
671-
:param str order: 'C' means row-major order (C-style), and
672-
'F' means column-major order (Fortran-style).
673-
:return: row vector matrix
681+
:param int/list[int]/tuple[int] ind: target indices
682+
:param int/list[int]/tuple[int]/Matrix row-vector v: values to place in
683+
matrix at target indices; if v is shorter than ind, it is repeated as necessary
684+
:param str mode: how out-of-bounds indices behave; 'raise' raises an
685+
error, 'wrap' wraps aoround, and 'clip' clips to the range
686+
687+
Note that the index is on the flattened matrix.
674688
'''
675-
result = []
676-
if order == 'C':
677-
for r in range(self.rows):
678-
for c in range(self.columns):
679-
result.append(as_wires(self[r, c], bitwidth=self.bits))
680-
elif order == 'F':
681-
for c in range(self.columns):
682-
for r in range(self.rows):
683-
result.append(as_wires(self[r, c], bitwidth=self.bits))
689+
count = self.rows * self.columns
690+
if isinstance(ind, int):
691+
ind = (ind,)
692+
elif not isinstance(ind, (tuple, list)):
693+
raise PyrtlError("Expected int or list-like indices, got %s" % type(ind))
694+
695+
if isinstance(v, int):
696+
v = (v,)
697+
698+
if isinstance(v, (tuple, list)) and len(v) == 0:
699+
return
700+
elif isinstance(v, Matrix):
701+
if v.rows != 1:
702+
raise PyrtlError(
703+
"Expected a row-vector matrix, instead got matrix with %d rows" % v.rows
704+
)
705+
706+
if mode not in ['raise', 'wrap', 'clip']:
707+
raise PyrtlError(
708+
"Unexpected mode %s; allowable modes are 'raise', 'wrap', and 'clip'" % mode
709+
)
710+
711+
def get_ix(ix):
712+
if ix < 0:
713+
ix = count - abs(ix)
714+
if ix < 0 or ix >= count:
715+
if mode == 'raise':
716+
raise PyrtlError("index %d is out of bounds with size %d" % (ix, count))
717+
elif mode == 'wrap':
718+
ix = ix % count
719+
elif mode == 'clip':
720+
ix = 0 if ix < 0 else count - 1
721+
return ix
722+
723+
def get_value(ix):
724+
if isinstance(v, (tuple, list)):
725+
if ix >= len(v):
726+
return v[-1] # if v is shorter than ind, repeat last as necessary
727+
return v[ix]
728+
elif isinstance(v, Matrix):
729+
if ix >= count:
730+
return v[0, -1]
731+
return v[0, ix]
732+
733+
for v_ix, mat_ix in enumerate(ind):
734+
mat_ix = get_ix(mat_ix)
735+
row = mat_ix // self.columns
736+
col = mat_ix % self.columns
737+
self[row, col] = get_value(v_ix)
738+
739+
def reshape(self, *newshape, order='C'):
740+
''' Create a matrix of the given shape from the current matrix.
741+
742+
:param int/ints/tuple[int] newshape: shape of the matrix to return;
743+
if a single int, will result in a 1-D row-vector of that length;
744+
if a tuple, will use values for number of rows and cols. Can also
745+
be a varargs.
746+
:param str order: 'C' means to read from self using
747+
row-major order (C-style), and 'F' means to read from self
748+
using column-major order (Fortran-style).
749+
:return: A copy of the matrix with same data, with a new number of rows/cols
750+
751+
One shape dimension in newshape can be -1; in this case, the value
752+
for that dimension is inferred from the other given dimension (if any)
753+
and the number of elements in the matrix.
754+
755+
Examples::
756+
int_matrix = [[0, 1, 2, 3], [4, 5, 6, 7]]
757+
matrix = Matrix.Matrix(2, 4, 4, value=int_matrix)
758+
759+
matrix.reshape(-1) == [[0, 1, 2, 3, 4, 5, 6, 7]]
760+
matrix.reshape(8) == [[0, 1, 2, 3, 4, 5, 6, 7]]
761+
matrix.reshape(1, 8) == [[0, 1, 2, 3, 4, 5, 6, 7]]
762+
matrix.reshape((1, 8)) == [[0, 1, 2, 3, 4, 5, 6, 7]]
763+
matrix.reshape((1, -1)) == [[0, 1, 2, 3, 4, 5, 6, 7]]
764+
765+
matrix.reshape(4, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]
766+
matrix.reshape(-1, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]
767+
matrix.reshape(4, -1) == [[0, 1], [2, 3], [4, 5], [6, 7]]
768+
'''
769+
count = self.rows * self.columns
770+
if isinstance(newshape, int):
771+
if newshape == -1:
772+
newshape = (1, count)
773+
else:
774+
newshape = (1, newshape)
775+
elif isinstance(newshape, tuple):
776+
if isinstance(newshape[0], tuple):
777+
newshape = newshape[0]
778+
if len(newshape) == 1:
779+
newshape = (1, newshape[0])
780+
if len(newshape) > 2:
781+
raise PyrtlError("length of newshape tuple must be <= 2")
782+
rows, cols = newshape
783+
if not isinstance(rows, int) or not isinstance(cols, int):
784+
raise PyrtlError(
785+
"newshape dimensions must be integers, instead got %s" % type(newshape)
786+
)
787+
if rows == -1 and cols == -1:
788+
raise PyrtlError("Both dimensions in newshape cannot be -1")
789+
if rows == -1:
790+
rows = count // cols
791+
newshape = (rows, cols)
792+
elif cols == -1:
793+
cols = count // rows
794+
newshape = (rows, cols)
684795
else:
796+
raise PyrtlError(
797+
"newshape can be an integer or tuple of integers, not %s" % type(newshape)
798+
)
799+
800+
rows, cols = newshape
801+
if rows * cols != count:
802+
raise PyrtlError(
803+
"Cannot reshape matrix of size %d into shape %s" % (count, str(newshape))
804+
)
805+
806+
if order not in 'CF':
685807
raise PyrtlError(
686808
"Invalid order %s. Acceptable orders are 'C' (for row-major C-style order) "
687809
"and 'F' (for column-major Fortran-style order)." % order
688810
)
689811

690-
return Matrix(rows=1, columns=len(result), bits=self.bits,
691-
value=[result], max_bits=self.max_bits)
812+
value = [[0] * cols for _ in range(rows)]
813+
ix = 0
814+
if order == 'C':
815+
# Read and write in row-wise order
816+
for newr in range(rows):
817+
for newc in range(cols):
818+
r = ix // self.columns
819+
c = ix % self.columns
820+
value[newr][newc] = self[r, c]
821+
ix += 1
822+
else:
823+
# Read and write in column-wise order
824+
for newc in range(cols):
825+
for newr in range(rows):
826+
r = ix % self.rows
827+
c = ix // self.rows
828+
value[newr][newc] = self[r, c]
829+
ix += 1
830+
831+
return Matrix(rows, cols, self.bits, self.signed, value, self.max_bits)
832+
833+
def flatten(self, order='C'):
834+
''' Flatten the matrix into a single row.
835+
836+
:param str order: 'C' means row-major order (C-style), and
837+
'F' means column-major order (Fortran-style).
838+
:return: A copy of the matrix flattened in to a row vector matrix
839+
'''
840+
return self.reshape(self.rows * self.columns, order=order)
692841

693842

694843
def multiply(first, second):

0 commit comments

Comments
 (0)