Skip to content

Commit e00bb40

Browse files
authored
Merge pull request #346 from pllab/docs-and-tests
Adding documentation and some tests for the passes and transforms
2 parents ad2c54e + 69f0f5d commit e00bb40

File tree

3 files changed

+182
-9
lines changed

3 files changed

+182
-9
lines changed

pyrtl/passes.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,7 @@ def replace_net_with_wire(new_wire):
216216

217217

218218
def common_subexp_elimination(block=None, abs_thresh=1, percent_thresh=0):
219-
"""
220-
Common Subexpression Elimination for PyRTL blocks
219+
""" Common Subexpression Elimination for PyRTL blocks.
221220
222221
:param block: the block to run the subexpression elimination on
223222
:param abs_thresh: absolute threshold for stopping optimization
@@ -235,6 +234,18 @@ def common_subexp_elimination(block=None, abs_thresh=1, percent_thresh=0):
235234

236235

237236
def _find_common_subexps(block):
237+
""" Finds nets that can be considered the same based on op type, op param, and arguments.
238+
239+
:param block: Block to operate over
240+
:return dict[LogicNet, [LogicNet]]: mapping from a logic net (with a placehold dest)
241+
representing the common subexp, to a list of nets matching that common subexp that
242+
can be replaced with the single common subexp.
243+
244+
Nets are the "same" if 1) their op types are the same, 2) their op_params are
245+
the same (e.g. same memory if a memory-related op), and 3) their arguments are
246+
the same (same constant value and bitwidth for const wires, otherwise same wire
247+
object). The destination wire for a net is not considered.
248+
"""
238249
net_table = {} # {net (without dest) : [net, ...]
239250
t = tuple() # just a placeholder
240251
const_dict = {}
@@ -252,6 +263,11 @@ def _find_common_subexps(block):
252263

253264

254265
def _const_to_int(wire, const_dict):
266+
""" Return a repr a Const (a tuple composed of width and value) for comparison with an 'is'.
267+
268+
If the wire is not a Const, just return the wire itself; comparison will be
269+
done on the identity of the wire object instead.
270+
"""
255271
if isinstance(wire, Const):
256272
# a very bad hack to make sure two consts will compare
257273
# correctly with an 'is'
@@ -268,6 +284,12 @@ def _const_to_int(wire, const_dict):
268284

269285

270286
def _replace_subexps(block, net_table):
287+
""" Removes unnecessary nets, connecting the common net's dest wire to unnecessary net's dest.
288+
289+
:param block: The block to operate over.
290+
:param net_table: A mapping from common subexpression (a net) to a list of nets
291+
that can be replaced with that common net.
292+
"""
271293
wire_map = {}
272294
unnecessary_nets = []
273295
for nets in net_table.values():
@@ -282,6 +304,16 @@ def _has_normal_dest_wire(net):
282304

283305

284306
def _process_nets_to_discard(nets, wire_map, unnecessary_nets):
307+
""" Helper for tracking how a group of related nets should be replaced with a common one.
308+
309+
:param nets: List of nets that are considered equal and which should
310+
be replaced by a single common net.
311+
:param wire_map: Dict that will be updated with a mapping from every
312+
old destination wire that needs to be removed, to the new destination
313+
wire with which it should be replaced.
314+
:param unnecessary_nets: List of nets that are to be discarded.
315+
316+
"""
285317
if len(nets) == 1:
286318
return # also deals with nets with no dest wires
287319
nets_to_consider = list(filter(_has_normal_dest_wire, nets))
@@ -297,7 +329,9 @@ def _process_nets_to_discard(nets, wire_map, unnecessary_nets):
297329

298330

299331
def _remove_unlistened_nets(block):
300-
""" Removes all nets that are not connected to an output wirevector
332+
""" Removes all nets that are not connected to an output wirevector.
333+
334+
:param block: The block to operate over.
301335
"""
302336

303337
listened_nets = set()
@@ -326,7 +360,12 @@ def add_to_listened(net):
326360

327361

328362
def _remove_unused_wires(block, keep_inputs=True):
329-
""" Removes all unconnected wires from a block"""
363+
""" Removes all unconnected wires from a block's wirevector_set.
364+
365+
:param block: The block to operate over.
366+
:param keep_inputs: If True, retain any Input wires that are not connected
367+
to any net.
368+
"""
330369
valid_wires = set()
331370
for logic_net in block.logic:
332371
valid_wires.update(logic_net.args, logic_net.dests)

pyrtl/transform.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626

2727
def net_transform(transform_func, block=None, **kwargs):
28-
""" Maps nets to new sets of nets according to a custom function
28+
""" Maps nets to new sets of nets according to a custom function.
2929
3030
:param transform_func:
3131
Function signature: func(orig_net (logicnet)) -> keep_orig_net (bool)
@@ -44,7 +44,7 @@ def net_transform(transform_func, block=None, **kwargs):
4444

4545

4646
def all_nets(transform_func):
47-
"""Decorator that wraps a net transform function"""
47+
""" Decorator that wraps a net transform function. """
4848
@functools.wraps(transform_func)
4949
def t_res(**kwargs):
5050
net_transform(transform_func, **kwargs)
@@ -107,10 +107,10 @@ def replace_wire(orig_wire, new_src, new_dst, block=None):
107107

108108

109109
def replace_wires(wire_map, block=None):
110-
""" Quickly replace all wires in a block
110+
""" Quickly replace all wires in a block.
111111
112-
:param {old_wire: new_wire} wire_map: mapping of old wires to
113-
new wires
112+
:param {old_wire: new_wire} wire_map: mapping of old wires to new wires
113+
:param block: block to operate over (defaults to working block)
114114
"""
115115
block = working_block(block)
116116
src_nets, dst_nets = block.net_connections(include_virtual_nodes=False)
@@ -119,6 +119,80 @@ def replace_wires(wire_map, block=None):
119119

120120

121121
def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=None):
122+
"""
123+
Replace orig_wire with new_src and/or new_dst. The net that orig_wire originates from
124+
(its source net) will now feed into new_src as its destination, and the nets that
125+
orig_wire went to (its destination nets) will be fed from new_dst as
126+
their respective arguments.
127+
128+
:param WireVector orig_wire: Wire to be replaced
129+
:param WireVector new_src: Wire to replace orig_wire, anywhere orig_wire is the
130+
destination of a net. Ignored if orig_wire equals new_src.
131+
:param WireVector new_dst: Wire to replace orig_wire, anywhere orig_wire is an
132+
argument of a net. Ignored if orig_wire equals new_dst.
133+
:param {WireVector: LogicNet} src_nets: Maps a wire to the net where it is a dest
134+
:param {WireVector: List[LogicNet]} dst_nets: Maps a wire to list of nets where it is an arg
135+
:param Block block: The block on which to operate (defaults to working block)
136+
137+
new_src will now originate from orig_wire's source net (meaning new_src will be that net's
138+
destination). new_dst will be now
139+
140+
This *updates* the src_nets and dst_nets maps that are passed in, such that:
141+
142+
```
143+
old_src_net = src_nets[orig_wire]
144+
src_nets[new_src] = old_src_net (where old_src_net.dests = (new_src,))
145+
```
146+
and
147+
```
148+
old_dst_nets = dst_nets[orig_wire]
149+
dst_nets[new_dst] = [old_dst_net (where old_dst_net.args replaces orig_wire with new_dst) foreach old_dst_net] # noqa
150+
```
151+
152+
This also removes and/or adds nets to the block's logic set.
153+
154+
For example, given the graph on left, `replace_wire_fast(w1, w4, w1, ...)` produces on right:
155+
156+
```
157+
a b c d a b c d
158+
| | | | | | | |
159+
net net net net
160+
| | | |
161+
w1 w2 ==> produces ==> w4 w1 w2
162+
| | | |
163+
net net
164+
| |
165+
w3 w3
166+
```
167+
168+
And given the graph on the left, `replace_wire_fast(w1, w1, w4, ...)` produces on the right:
169+
```
170+
a b c d a b c d
171+
| | | | | | | |
172+
net net net net
173+
| | | |
174+
w1 w2 ==> produces ==> w1 w4 w2
175+
| | | |
176+
net net
177+
| |
178+
w3 w3
179+
```
180+
181+
Calling `replace_wire_fast(w1, w4, w4, ...)`, then, fully replaces w1 with w3 in both
182+
its argument and dest positions:
183+
184+
```
185+
a b c d a b c d
186+
| | | | | | | |
187+
net net net net
188+
| | | |
189+
w1 w2 ==> produces ==> w4 w2
190+
| | | |
191+
net net
192+
| |
193+
w3 w3
194+
```
195+
"""
122196
def remove_net(net_):
123197
for arg in set(net_.args):
124198
dst_nets[arg].remove(net_)

tests/test_transform.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,66 @@ def test_wire_used_in_multiple_places(self):
197197
self.assertNotIn(old_wire, block.wirevector_set)
198198
block.sanity_check()
199199

200+
def test_replace_only_src_wire(self):
201+
a, b, c, d = pyrtl.input_list('a/1 b/1 c/1 d/1')
202+
w1 = a & b
203+
w1.name = 'w1'
204+
w2 = c | d
205+
w2.name = 'w2'
206+
w3 = w1 ^ w2
207+
w3.name = 'w3'
208+
o = pyrtl.Output(1, 'o')
209+
o <<= w3
210+
211+
w4 = pyrtl.WireVector(1, 'w4')
212+
src_nets, dst_nets = pyrtl.working_block().net_connections()
213+
214+
w1_src_net = src_nets[w1]
215+
w1_dst_net = dst_nets[w1][0]
216+
self.assertEqual(w1_src_net.args, (a, b))
217+
self.assertEqual(w1_src_net.dests, (w1,))
218+
self.assertEqual(w1_dst_net.args, (w1, w2))
219+
self.assertEqual(w1_dst_net.dests, (w3,))
220+
self.assertNotIn(w4, src_nets)
221+
222+
pyrtl.transform.replace_wire_fast(w1, w4, w1, src_nets, dst_nets)
223+
224+
self.assertNotIn(w1, src_nets) # The maps have been updated...
225+
self.assertEqual(dst_nets[w1], [w1_dst_net])
226+
w4_src_net = src_nets[w4] # ...but the net can't be, so new updated versions were created
227+
self.assertEqual(w4_src_net.args, w1_src_net.args)
228+
self.assertEqual(w4_src_net.dests, (w4,))
229+
230+
def test_replace_only_dst_wire(self):
231+
a, b, c, d = pyrtl.input_list('a/1 b/1 c/1 d/1')
232+
w1 = a & b
233+
w1.name = 'w1'
234+
w2 = c | d
235+
w2.name = 'w2'
236+
w3 = w1 ^ w2
237+
w3.name = 'w3'
238+
o = pyrtl.Output(1, 'o')
239+
o <<= w3
240+
241+
w4 = pyrtl.WireVector(1, 'w4')
242+
src_nets, dst_nets = pyrtl.working_block().net_connections()
243+
244+
w1_src_net = src_nets[w1]
245+
w1_dst_net = dst_nets[w1][0]
246+
self.assertEqual(w1_src_net.args, (a, b))
247+
self.assertEqual(w1_src_net.dests, (w1,))
248+
self.assertEqual(w1_dst_net.args, (w1, w2))
249+
self.assertEqual(w1_dst_net.dests, (w3,))
250+
self.assertNotIn(w4, src_nets)
251+
252+
pyrtl.transform.replace_wire_fast(w1, w1, w4, src_nets, dst_nets)
253+
254+
self.assertNotIn(w1, dst_nets) # The maps have been updated...
255+
self.assertEqual(src_nets[w1], w1_src_net)
256+
w4_dst_net = dst_nets[w4][0] # ...but the net can't be, so new versions were created
257+
self.assertEqual(w4_dst_net.args, (w4, w2))
258+
self.assertEqual(w4_dst_net.dests, w1_dst_net.dests)
259+
200260

201261
# this code needs mocking from python 3's unittests to work
202262
"""

0 commit comments

Comments
 (0)