10
10
11
11
12
12
class Matrix (object ):
13
- '''Class for making a Matrix using PyRTL.
13
+ ''' Class for making a Matrix using PyRTL.
14
14
15
15
Provides the ability to perform different matrix operations.
16
16
'''
@@ -21,18 +21,18 @@ class Matrix(object):
21
21
def __init__ (self , rows , columns , bits , signed = False , value = None , max_bits = 64 ):
22
22
''' Constructs a Matrix object.
23
23
24
- :param int rows: the number of rows in the matrix. Must be greater than 0.
25
- :param int columns: the number of columns in the matrix. Must be greater than 0.
26
- :param int bits: The amount of bits per wirevector. Must be greater than 0.
24
+ :param int rows: the number of rows in the matrix. Must be greater than 0
25
+ :param int columns: the number of columns in the matrix. Must be greater than 0
26
+ :param int bits: The amount of bits per wirevector. Must be greater than 0
27
27
:param bool signed: Currently not supported (will be added in the future)
28
28
:param (WireVector/list) value: The value you want to initialize the Matrix with.
29
29
If a WireVector, must be of size `rows * columns * bits`. If a list, must have
30
30
`rows` rows and `columns` columns, and every element must fit in `bits` size.
31
- If not given, the matrix initializes to 0.
31
+ If not given, the matrix initializes to 0
32
32
:param int max_bits: The maximum number of bits each wirevector can have, even
33
33
after operations like adding two matrices together results in larger
34
- resulting wirevectors.
35
- :return: a constructed Matrix object.
34
+ resulting wirevectors
35
+ :return: a constructed Matrix object
36
36
'''
37
37
if not isinstance (rows , int ):
38
38
raise PyrtlError ('Rows must be of type int, instead "%s" '
@@ -107,15 +107,15 @@ def __init__(self, rows, columns, bits, signed=False, value=None, max_bits=64):
107
107
def bits (self ):
108
108
''' Gets the number of bits each value is allowed to hold.
109
109
110
- :return: an integer representing the number of bits.
110
+ :return: an integer representing the number of bits
111
111
'''
112
112
return self ._bits
113
113
114
114
@bits .setter
115
115
def bits (self , bits ):
116
116
''' Sets the number of bits.
117
117
118
- :param int bits: The number of bits. Must be greater than 0.
118
+ :param int bits: The number of bits. Must be greater than 0
119
119
120
120
Called automatically when bits is changed.
121
121
NOTE: This function will truncate the most significant bits.
@@ -139,17 +139,17 @@ def __len__(self):
139
139
140
140
:return: an integer representing the output WireVector bitwidth
141
141
142
- Used with default len() function
142
+ Used with default ` len()` function
143
143
'''
144
144
return self .bits * self .rows * self .columns
145
145
146
146
def to_wirevector (self ):
147
147
''' Outputs the PyRTL Matrix as a singular concatenated Wirevector.
148
148
149
- :return: a Wirevector representing the whole PyRTL matrix.
149
+ :return: a Wirevector representing the whole PyRTL matrix
150
150
151
- For instance, if we had a 2 x 1 matrix [[wire_a, wire_b]] it would
152
- return the concatenated wire: wire = wire_a.wire_b
151
+ For instance, if we had a 2 x 1 matrix ` [[wire_a, wire_b]]` it would
152
+ return the concatenated wire: ` wire = wire_a.wire_b`
153
153
'''
154
154
result = []
155
155
@@ -162,7 +162,7 @@ def to_wirevector(self):
162
162
def transpose (self ):
163
163
''' Constructs the transpose of the matrix
164
164
165
- :return: a Matrix object representing the transpose.
165
+ :return: a Matrix object representing the transpose
166
166
'''
167
167
result = Matrix (self .columns , self .rows , self .bits , max_bits = self .max_bits )
168
168
for i in range (result .rows ):
@@ -173,7 +173,7 @@ def transpose(self):
173
173
def __reversed__ (self ):
174
174
''' Constructs the reverse of matrix
175
175
176
- :return: a Matrix object representing the reverse.
176
+ :return: a Matrix object representing the reverse
177
177
178
178
Used with the reversed() method
179
179
'''
@@ -315,8 +315,7 @@ def __setitem__(self, key, value):
315
315
:param (slice/int rows, slice/int columns) key: The key value to set
316
316
:param Wirevector/int/Matrix value: The value in which to set the key
317
317
318
- Called when setting a value using square brackets.
319
- (e.g. matrix[a, b] = value)
318
+ Called when setting a value using square brackets (e.g. `matrix[a, b] = value`).
320
319
321
320
The value given will be truncated to match the bitwidth of all the elements
322
321
in the matrix.
@@ -434,9 +433,9 @@ def copy(self):
434
433
def __iadd__ (self , other ):
435
434
''' Perform the in-place addition operation.
436
435
437
- :return: a Matrix object with the element wise addition being preformed.
436
+ :return: a Matrix object with the elementwise addition being preformed
438
437
439
- Is used with a += b. Performs an elementwise addition.
438
+ Is used with ` a += b` . Performs an elementwise addition.
440
439
'''
441
440
new_value = (self + other )
442
441
self ._matrix = new_value ._matrix
@@ -446,9 +445,9 @@ def __iadd__(self, other):
446
445
def __add__ (self , other ):
447
446
''' Perform the addition operation.
448
447
449
- :return: a Matrix object with the element wise addition being performed.
448
+ :return: a Matrix object with the element wise addition being performed
450
449
451
- Is used with a + b. Performs an elementwise addition.
450
+ Is used with ` a + b` . Performs an elementwise addition.
452
451
'''
453
452
if not isinstance (other , Matrix ):
454
453
raise PyrtlError ('error: expecting a Matrix, '
@@ -478,9 +477,9 @@ def __isub__(self, other):
478
477
''' Perform the inplace subtraction opperation.
479
478
480
479
:Matrix other: the PyRTL Matrix to subtract
481
- :return: a Matrix object with the element wise subtraction being performed.
480
+ :return: a Matrix object with the element wise subtraction being performed
482
481
483
- Is used with a -= b. Performs an elementwise subtraction.
482
+ Is used with ` a -= b` . Performs an elementwise subtraction.
484
483
'''
485
484
new_value = self - other
486
485
self ._matrix = new_value ._matrix
@@ -491,11 +490,11 @@ def __sub__(self, other):
491
490
''' Perform the subtraction operation.
492
491
493
492
:Matrix other: the PyRTL Matrix to subtract
494
- :return: a Matrix object with the elementwise subtraction being performed.
493
+ :return: a Matrix object with the elementwise subtraction being performed
495
494
496
- Is used with a - b. Performs an elementwise subtraction.
495
+ Is used with ` a - b` . Performs an elementwise subtraction.
497
496
498
- Note: If using unsigned numbers, the result will be floored at 0
497
+ Note: If using unsigned numbers, the result will be floored at 0.
499
498
'''
500
499
if not isinstance (other , Matrix ):
501
500
raise PyrtlError ('error: expecting a Matrix, '
@@ -531,10 +530,10 @@ def __sub__(self, other):
531
530
def __imul__ (self , other ):
532
531
''' Perform the in-place multiplication operation.
533
532
534
- :Matrix/Wirevector other: the Matrix or scalar to multiply
535
- :return: a Matrix object with the resulting multiplication operation being preformed.
533
+ :param Matrix/Wirevector other: the Matrix or scalar to multiply
534
+ :return: a Matrix object with the resulting multiplication operation being preformed
536
535
537
- Is used with a *= b. Performs an elementwise or scalar multiplication.
536
+ Is used with ` a *= b` . Performs an elementwise or scalar multiplication.
538
537
'''
539
538
new_value = self * other
540
539
self ._matrix = new_value ._matrix
@@ -544,10 +543,10 @@ def __imul__(self, other):
544
543
def __mul__ (self , other ):
545
544
''' Perform the elementwise or scalar multiplication operation.
546
545
547
- :Matrix/Wirevector other: the Matrix to multiply
548
- :return: a Matrix object with the resulting multiplication operation being performed.
546
+ :param Matrix/Wirevector other: the Matrix to multiply
547
+ :return: a Matrix object with the resulting multiplication operation being performed
549
548
550
- Is used with a * b.
549
+ Is used with ` a * b` .
551
550
'''
552
551
553
552
if isinstance (other , Matrix ):
@@ -583,7 +582,7 @@ def __imatmul__(self, other):
583
582
:param Matrix other: the second matrix.
584
583
:return: a PyRTL Matrix that contains the matrix multiplication product of this and other
585
584
586
- Is used with a @= b
585
+ Is used with ` a @= b`.
587
586
588
587
Note: The matmul symbol (@) only works in python 3.5+. Otherwise you must
589
588
call `__imatmul__(other)`.
@@ -601,7 +600,7 @@ def __matmul__(self, other):
601
600
:param Matrix other: the second matrix.
602
601
:return: a PyRTL Matrix that contains the matrix multiplication product of this and other
603
602
604
- Is used with a @ b
603
+ Is used with ` a @ b`.
605
604
606
605
Note: The matmul symbol (@) only works in python 3.5+. Otherwise you must
607
606
call `__matmul__(other)`.
@@ -633,7 +632,7 @@ def __ipow__(self, power):
633
632
:param int power: the power to perform the matrix on
634
633
:return: a PyRTL Matrix that contains the matrix power product
635
634
636
- Is used with a **= b
635
+ Is used with ` a **= b`.
637
636
'''
638
637
new_value = self ** power
639
638
self ._matrix = new_value ._matrix
@@ -646,7 +645,7 @@ def __pow__(self, power):
646
645
:param int power: the power to perform the matrix on
647
646
:return: a PyRTL Matrix that contains the matrix power product
648
647
649
- Is used with a ** b
648
+ Is used with ` a ** b`.
650
649
'''
651
650
if not isinstance (power , int ):
652
651
raise PyrtlError ('Unexpected power given. Type int expected, '
@@ -754,6 +753,7 @@ def reshape(self, *newshape, order='C'):
754
753
and the number of elements in the matrix.
755
754
756
755
Examples::
756
+
757
757
int_matrix = [[0, 1, 2, 3], [4, 5, 6, 7]]
758
758
matrix = Matrix.Matrix(2, 4, 4, value=int_matrix)
759
759
@@ -835,7 +835,7 @@ def flatten(self, order='C'):
835
835
''' Flatten the matrix into a single row.
836
836
837
837
:param str order: 'C' means row-major order (C-style), and
838
- 'F' means column-major order (Fortran-style).
838
+ 'F' means column-major order (Fortran-style)
839
839
:return: A copy of the matrix flattened in to a row vector matrix
840
840
'''
841
841
return self .reshape (self .rows * self .columns , order = order )
@@ -844,9 +844,9 @@ def flatten(self, order='C'):
844
844
def multiply (first , second ):
845
845
''' Perform the elementwise or scalar multiplication operation.
846
846
847
- :param Matrix first: first matrix.
848
- :param Matrix/Wirevector second: second matrix.
849
- :return: a Matrix object with the element wise or scaler multiplication being performed.
847
+ :param Matrix first: first matrix
848
+ :param Matrix/Wirevector second: second matrix
849
+ :return: a Matrix object with the element wise or scaler multiplication being performed
850
850
'''
851
851
if not isinstance (first , Matrix ):
852
852
raise PyrtlError ('error: expecting a Matrix, '
@@ -858,9 +858,9 @@ def sum(matrix, axis=None, bits=None):
858
858
''' Returns the sum of all the values in a matrix
859
859
860
860
:param Matrix/Wirevector matrix: the matrix to perform sum operation on.
861
- If it is a WireVector, it will return itself.
862
- :param None/int axis: The axis to perform the operation on.
863
- None refers to sum of all item. 0 is sum of column. 1 is sum of rows. Defaults to None.
861
+ If it is a WireVector, it will return itself
862
+ :param None/int axis: The axis to perform the operation on
863
+ None refers to sum of all item. 0 is sum of column. 1 is sum of rows. Defaults to None
864
864
:param int bits: The bits per value of the sum. Defaults to bits of old matrix
865
865
:return: A wirevector or Matrix representing sum
866
866
'''
@@ -922,9 +922,9 @@ def min(matrix, axis=None, bits=None):
922
922
''' Returns the minimum value in a matrix.
923
923
924
924
:param Matrix/Wirevector matrix: the matrix to perform min operation on.
925
- If it is a WireVector, it will return itself.
926
- :param None/int axis: The axis to perform the operation on.
927
- None refers to min of all item. 0 is min of column. 1 is min of rows. Defaults to None.
925
+ If it is a WireVector, it will return itself
926
+ :param None/int axis: The axis to perform the operation on
927
+ None refers to min of all item. 0 is min of column. 1 is min of rows. Defaults to None
928
928
:param int bits: The bits per value of the min. Defaults to bits of old matrix
929
929
:return: A WireVector or Matrix representing the min value
930
930
'''
@@ -986,10 +986,10 @@ def max(matrix, axis=None, bits=None):
986
986
''' Returns the max value in a matrix.
987
987
988
988
:param Matrix/Wirevector matrix: the matrix to perform max operation on.
989
- If it is a wirevector, it will return itself.
990
- :param None/int axis: The axis to perform the operation on.
989
+ If it is a wirevector, it will return itself
990
+ :param None/int axis: The axis to perform the operation on
991
991
None refers to max of all items. 0 is max of the columns. 1 is max of rows.
992
- Defaults to None.
992
+ Defaults to None
993
993
:param int bits: The bits per value of the max. Defaults to bits of old matrix
994
994
:return: A WireVector or Matrix representing the max value
995
995
'''
@@ -1053,10 +1053,10 @@ def argmax(matrix, axis=None, bits=None):
1053
1053
''' Returns the index of the max value of the matrix.
1054
1054
1055
1055
:param Matrix/Wirevector matrix: the matrix to perform argmax operation on.
1056
- If it is a WireVector, it will return itself.
1056
+ If it is a WireVector, it will return itself
1057
1057
:param None/int axis: The axis to perform the operation on.
1058
1058
None refers to argmax of all items. 0 is argmax of the columns. 1 is argmax of rows.
1059
- Defaults to None.
1059
+ Defaults to None
1060
1060
:param int bits: The bits per value of the argmax. Defaults to bits of old matrix
1061
1061
:return: A WireVector or Matrix representing the argmax value
1062
1062
@@ -1126,9 +1126,9 @@ def argmax(matrix, axis=None, bits=None):
1126
1126
def dot (first , second ):
1127
1127
''' Performs the dot product on two matrices.
1128
1128
1129
- :param Matrix first: the first matrix.
1130
- :param Matrix second: the second matrix.
1131
- :return: a PyRTL Matrix that contains the dot product of the two PyRTL Matrices.
1129
+ :param Matrix first: the first matrix
1130
+ :param Matrix second: the second matrix
1131
+ :return: a PyRTL Matrix that contains the dot product of the two PyRTL Matrices
1132
1132
1133
1133
Specifically, the dot product on two matrices is
1134
1134
* If either first or second are WireVectors/have both rows and columns
@@ -1182,15 +1182,15 @@ def hstack(*matrices):
1182
1182
1183
1183
All the matrices must have the same number of rows and same 'signed' value.
1184
1184
1185
- For example:
1185
+ For example::
1186
1186
1187
1187
m1 = Matrix(2, 3, bits=5, value=[[1,2,3],
1188
1188
[4,5,6]])
1189
1189
m2 = Matrix(2, 1, bits=10, value=[[17],
1190
1190
[23]]])
1191
1191
m3 = hstack(m1, m2)
1192
1192
1193
- m3 looks like:
1193
+ m3 looks like::
1194
1194
1195
1195
[[1,2,3,17],
1196
1196
[4,5,6,23]]
@@ -1236,14 +1236,14 @@ def vstack(*matrices):
1236
1236
1237
1237
All the matrices must have the same number of columns and same 'signed' value.
1238
1238
1239
- For example:
1239
+ For example::
1240
1240
1241
1241
m1 = Matrix(2, 3, bits=5, value=[[1,2,3],
1242
1242
[4,5,6]])
1243
1243
m2 = Matrix(1, 3, bits=10, value=[[7,8,9]])
1244
1244
m3 = vstack(m1, m2)
1245
1245
1246
- m3 looks like:
1246
+ m3 looks like::
1247
1247
1248
1248
[[1,2,3],
1249
1249
[4,5,6],
@@ -1350,7 +1350,7 @@ def list_to_int(matrix, n_bits):
1350
1350
1351
1351
:param list[list[int]] matrix: a pure Python list of lists representing a matrix
1352
1352
:param int n_bits: number of bits to be used to represent each element; if an
1353
- element doesn't fit in n_bits, it truncates the most significant bits.
1353
+ element doesn't fit in n_bits, it truncates the most significant bits
1354
1354
:return int: a N*n_bits wide wirevector containing the elements of `matrix`,
1355
1355
where N is the number of elements in `matrix`
1356
1356
0 commit comments