Skip to content

Commit b4902ec

Browse files
committed
Add doctest examples for rtllib/multipliers.py. Also fix bugs in
`simple_mult` and `complex_mult` where they would signal `done` before `start` due to initially empty registers.
1 parent 49eefa0 commit b4902ec

File tree

2 files changed

+147
-14
lines changed

2 files changed

+147
-14
lines changed

pyrtl/rtllib/multipliers.py

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@ def simple_mult(
2323
Requires very small area (it uses only a single adder), but has long delay (worst
2424
case is ``len(A)`` cycles).
2525
26+
.. doctest only::
27+
28+
>>> import pyrtl
29+
>>> pyrtl.reset_working_block()
30+
31+
Example::
32+
33+
>>> a = pyrtl.Input(name="a", bitwidth=4)
34+
>>> b = pyrtl.Input(name="b", bitwidth=4)
35+
>>> start = pyrtl.Input(name="start", bitwidth=1)
36+
37+
>>> output, done = pyrtl.rtllib.multipliers.simple_mult(a, b, start=start)
38+
>>> output.name = "output"
39+
>>> done.name = "done"
40+
41+
>>> sim = pyrtl.Simulation()
42+
>>> sim.step({"a": 2, "b": 3, "start": True})
43+
>>> while not sim.inspect("done"):
44+
... sim.step({"a": 0, "b": 0, "start": False})
45+
>>> sim.inspect("output")
46+
6
47+
2648
:param A: Input wire for the multiplication.
2749
:param B: Input wire for the multiplication.
2850
:param start: A one-bit input that indicates when the inputs are ready.
@@ -38,21 +60,23 @@ def simple_mult(
3860
areg = pyrtl.Register(alen)
3961
breg = pyrtl.Register(blen + alen)
4062
accum = pyrtl.Register(blen + alen)
41-
done = areg == 0 # Multiplication is finished when a becomes 0
63+
done = pyrtl.WireVector(bitwidth=1)
4264

4365
# During multiplication, shift a right every cycle, b left every cycle
4466
with pyrtl.conditional_assignment:
4567
with start: # initialization
4668
areg.next |= A
4769
breg.next |= B
4870
accum.next |= 0
49-
with ~done: # don't run when there's no work to do
71+
with areg != 0: # don't run when there's no work to do
5072
areg.next |= areg[1:] # right shift
5173
breg.next |= pyrtl.concat(breg, pyrtl.Const(0, 1)) # left shift
5274
a_0_val = areg[0].sign_extended(len(accum))
5375

5476
# adds to accum only when LSB of areg is 1
5577
accum.next |= accum + (a_0_val & breg)
78+
with pyrtl.otherwise:
79+
done |= True
5680

5781
return accum, done
5882

@@ -82,6 +106,29 @@ def complex_mult(
82106
"""Generate shift-and-add multiplier that can shift and add multiple bits per clock
83107
cycle. Uses substantially more space than :func:`simple_mult` but is much faster.
84108
109+
.. doctest only::
110+
111+
>>> import pyrtl
112+
>>> pyrtl.reset_working_block()
113+
114+
Example::
115+
116+
>>> a = pyrtl.Input(name="a", bitwidth=4)
117+
>>> b = pyrtl.Input(name="b", bitwidth=4)
118+
>>> start = pyrtl.Input(name="start", bitwidth=1)
119+
120+
>>> output, done = pyrtl.rtllib.multipliers.complex_mult(
121+
... a, b, shifts=2, start=start)
122+
>>> output.name = "output"
123+
>>> done.name = "done"
124+
125+
>>> sim = pyrtl.Simulation()
126+
>>> sim.step({"a": 2, "b": 3, "start": True})
127+
>>> while not sim.inspect("done"):
128+
... sim.step({"a": 0, "b": 0, "start": False})
129+
>>> sim.inspect("output")
130+
6
131+
85132
:param A: Input wire for the multiplication.
86133
:param B: Input wire for the multiplication.
87134
:param shifts: Number of spaces :class:`.Register` is to be shifted per clock cycle.
@@ -96,7 +143,7 @@ def complex_mult(
96143
areg = pyrtl.Register(alen)
97144
breg = pyrtl.Register(alen + blen)
98145
accum = pyrtl.Register(alen + blen)
99-
done = areg == 0 # Multiplication is finished when a becomes 0
146+
done = pyrtl.WireVector(bitwidth=1)
100147
if (shifts > alen) or (shifts > blen):
101148
msg = (
102149
"shift is larger than one or both of the parameters A or B, please choose "
@@ -112,12 +159,15 @@ def complex_mult(
112159
breg.next |= B
113160
accum.next |= 0
114161

115-
with ~done: # don't run when there's no work to do
162+
with areg != 0: # don't run when there's no work to do
116163
# "Multiply" shifted breg by LSB of areg by cond. adding
117164
areg.next |= pyrtl.shift_right_logical(areg, shifts)
118165
breg.next |= pyrtl.shift_left_logical(breg, shifts)
119166
accum.next |= accum + _one_cycle_mult(areg, breg, shifts)
120167

168+
with pyrtl.otherwise:
169+
done |= True
170+
121171
return accum, done
122172

123173

@@ -158,6 +208,24 @@ def tree_multiplier(
158208
159209
Delay is `O(log(N))`, while area is `O(N^2)`.
160210
211+
.. doctest only::
212+
213+
>>> import pyrtl
214+
>>> pyrtl.reset_working_block()
215+
216+
Example::
217+
218+
>>> a = pyrtl.Input(name="a", bitwidth=4)
219+
>>> b = pyrtl.Input(name="b", bitwidth=4)
220+
>>> output = pyrtl.Output(name="output")
221+
222+
>>> output <<= pyrtl.rtllib.multipliers.tree_multiplier(a, b)
223+
224+
>>> sim = pyrtl.Simulation()
225+
>>> sim.step({"a": 2, "b": 3})
226+
>>> sim.inspect("output")
227+
6
228+
161229
:param A: Input wire for the multiplication.
162230
:param B: Input wire for the multiplication.
163231
:param reducer: Reducing the tree with a :func:`~.adders.wallace_reducer` or a
@@ -190,7 +258,26 @@ def tree_multiplier(
190258
def signed_tree_multiplier(
191259
A, B, reducer=adders.wallace_reducer, adder_func=adders.kogge_stone
192260
):
193-
"""Same as :func:`tree_multiplier`, but uses two's-complement signed integers."""
261+
"""Same as :func:`tree_multiplier`, but uses two's-complement signed integers.
262+
263+
.. doctest only::
264+
265+
>>> import pyrtl
266+
>>> pyrtl.reset_working_block()
267+
268+
Example::
269+
270+
>>> a = pyrtl.Input(name="a", bitwidth=4)
271+
>>> b = pyrtl.Input(name="b", bitwidth=4)
272+
>>> output = pyrtl.Output(name="output")
273+
274+
>>> output <<= pyrtl.rtllib.multipliers.signed_tree_multiplier(a, b)
275+
276+
>>> sim = pyrtl.Simulation()
277+
>>> sim.step({"a": -2, "b": 3})
278+
>>> pyrtl.val_to_signed_integer(sim.inspect("output"), bitwidth=output.bitwidth)
279+
-6
280+
"""
194281
if len(A) == 1 or len(B) == 1:
195282
msg = "sign bit required, one or both wires too small"
196283
raise pyrtl.PyrtlError(msg)
@@ -227,6 +314,27 @@ def fused_multiply_adder(
227314
these operations, rather than doing them separately, one reduces both the area and
228315
the timing delay of the circuit.
229316
317+
.. doctest only::
318+
319+
>>> import pyrtl
320+
>>> pyrtl.reset_working_block()
321+
322+
Example::
323+
324+
>>> a = pyrtl.Input(name="a", bitwidth=4)
325+
>>> b = pyrtl.Input(name="b", bitwidth=4)
326+
>>> c = pyrtl.Input(name="c", bitwidth=4)
327+
>>> output = pyrtl.Output(name="output")
328+
329+
>>> output <<= pyrtl.rtllib.multipliers.fused_multiply_adder(a, b, c)
330+
331+
>>> sim = pyrtl.Simulation()
332+
>>> sim.step({"a": 2, "b": 3, "c": 4})
333+
>>> pyrtl.val_to_signed_integer(sim.inspect("output"), bitwidth=output.bitwidth)
334+
10
335+
>>> 2 * 3 + 4
336+
10
337+
230338
:param mult_A: Input wire for the multiplication.
231339
:param mult_B: Input wire for the multiplication.
232340
:param add: Input wire for the addition.
@@ -251,14 +359,35 @@ def generalized_fma(
251359
signed: bool = False, # noqa: ARG001
252360
reducer: Callable = adders.wallace_reducer,
253361
adder_func: Callable = adders.kogge_stone,
254-
):
362+
) -> pyrtl.WireVector:
255363
"""Generated an optimized fused multiply adder.
256364
257365
A generalized FMA unit that multiplies each pair of numbers in ``mult_pairs``, then
258366
adds up the resulting products and all the values of the ``add_wires``. This is
259367
faster than multiplying and adding separately because you avoid unnecessary adder
260368
structures for intermediate representations.
261369
370+
.. doctest only::
371+
372+
>>> import pyrtl
373+
>>> pyrtl.reset_working_block()
374+
375+
Example::
376+
377+
>>> mult_pairs = [(pyrtl.Const(2), pyrtl.Const(3)),
378+
... (pyrtl.Const(4), pyrtl.Const(5))]
379+
>>> add_wires = [pyrtl.Const(6), pyrtl.Const(7)]
380+
>>> output = pyrtl.Output(name="output")
381+
382+
>>> output <<= pyrtl.rtllib.multipliers.generalized_fma(mult_pairs, add_wires)
383+
384+
>>> sim = pyrtl.Simulation()
385+
>>> sim.step()
386+
>>> sim.inspect("output")
387+
39
388+
>>> 2 * 3 + 4 * 5 + 6 + 7
389+
39
390+
262391
:param mult_pairs: Either ``None`` (if there are no pairs to multiply) or a list of
263392
pairs of wires to multiply together: ``[(mult1_1, mult1_2), ...]``
264393
:param add_wires: Either ``None`` (if there are no individual items to add other

tests/rtllib/test_multipliers.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import doctest
12
import random
23
import unittest
34

@@ -6,6 +7,15 @@
67
from pyrtl.rtllib import adders, multipliers
78

89

10+
class TestDocTests(unittest.TestCase):
11+
"""Test documentation examples."""
12+
13+
def test_doctests(self):
14+
failures, tests = doctest.testmod(m=pyrtl.rtllib.multipliers)
15+
self.assertGreater(tests, 0)
16+
self.assertEqual(failures, 0)
17+
18+
919
class TestSimpleMult(unittest.TestCase):
1020
def setUp(self):
1121
pyrtl.reset_working_block()
@@ -45,12 +55,11 @@ def mult_t_base(self, len_a, len_b):
4555
for x_val, y_val in zip(xvals, yvals):
4656
sim = pyrtl.Simulation()
4757
sim.step({a: x_val, b: y_val, reset: 1})
48-
for _cycle in range(len(a) + 1):
58+
while not sim.inspect("done"):
4959
sim.step({a: 0, b: 0, reset: 0})
5060

5161
# Extracting the values and verifying correctness
5262
mult_results.append(sim.inspect("product"))
53-
self.assertEqual(sim.inspect("done"), 1)
5463
self.assertEqual(mult_results, true_result)
5564

5665

@@ -100,16 +109,11 @@ def mult_t_base(self, len_a, len_b, shifts):
100109
for x_val, y_val in zip(xvals, yvals):
101110
sim = pyrtl.Simulation()
102111
sim.step({a: x_val, b: y_val, reset: 1})
103-
if shifts <= len_a:
104-
length = len_a // shifts + (1 if len_a % shifts == 0 else 2)
105-
else:
106-
length = len_a + 1
107-
for _cycle in range(length):
112+
while not sim.inspect("done"):
108113
sim.step({a: 0, b: 0, reset: 0})
109114

110115
# Extracting the values and verifying correctness
111116
mult_results.append(sim.inspect("product"))
112-
self.assertEqual(sim.inspect("done"), 1)
113117
self.assertEqual(mult_results, true_result)
114118

115119

0 commit comments

Comments
 (0)