Skip to content

Commit 072bf89

Browse files
authored
Merge pull request #347 from pllab/transform-cleanup
Remove `replace_wire`, use `replace_wire_fast` instead; add tests and more documentation
2 parents 3595dda + afd91a1 commit 072bf89

File tree

3 files changed

+69
-42
lines changed

3 files changed

+69
-42
lines changed

pyrtl/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,9 @@
110110
from .passes import one_bit_selects
111111
from .passes import two_way_concat
112112

113-
from .transform import net_transform, wire_transform, replace_wire, copy_block, clone_wire
113+
from .transform import net_transform
114+
from .transform import wire_transform
115+
from .transform import copy_block
116+
from .transform import clone_wire
117+
from .transform import replace_wires
118+
from .transform import replace_wire_fast

pyrtl/transform.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -64,51 +64,27 @@ def wire_transform(transform_func, select_types=WireVector,
6464
:param select_types: Type or Tuple of types of WireVectors to replace
6565
:param exclude_types: Type or Tuple of types of WireVectors to exclude from replacement
6666
:param block: The Block to replace wires on
67+
68+
Note that if both new_src and new_dst don't equal orig_wire, orig_wire will
69+
be removed from the block entirely.
6770
"""
6871
block = working_block(block)
72+
src_nets, dst_nets = block.net_connections(include_virtual_nodes=False)
6973
for orig_wire in block.wirevector_subset(select_types, exclude_types):
7074
new_src, new_dst = transform_func(orig_wire)
71-
replace_wire(orig_wire, new_src, new_dst, block)
75+
replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block)
7276

7377

7478
def all_wires(transform_func):
75-
"""Decorator that wraps a wire transform function"""
79+
""" Decorator that wraps a wire transform function. """
7680
@functools.wraps(transform_func)
7781
def t_res(**kwargs):
7882
wire_transform(transform_func, **kwargs)
7983
return t_res
8084

8185

82-
def replace_wire(orig_wire, new_src, new_dst, block=None):
83-
block = working_block(block)
84-
if new_src is not orig_wire:
85-
# don't need to add the new_src and new_dst because they were made added at creation
86-
for net in block.logic:
87-
for wire in net.dests: # problem is that tuples use the == operator when using 'in'
88-
if wire is orig_wire:
89-
new_net = LogicNet(
90-
op=net.op, op_param=net.op_param, args=net.args,
91-
dests=tuple(new_src if w is orig_wire else w for w in net.dests))
92-
block.add_net(new_net)
93-
block.logic.remove(net)
94-
break
95-
96-
if new_dst is not orig_wire:
97-
for net in block.logic:
98-
for wire in set(net.args):
99-
if wire is orig_wire:
100-
new_net = LogicNet(
101-
op=net.op, op_param=net.op_param, dests=net.dests,
102-
args=tuple(new_src if w is orig_wire else w for w in net.args))
103-
block.add_net(new_net)
104-
block.logic.remove(net)
105-
106-
if new_dst is not orig_wire and new_src is not orig_wire:
107-
block.remove_wirevector(orig_wire)
108-
109-
11086
def replace_wires(wire_map, block=None):
111-
""" Quickly replace all wires in a block.
87+
""" Replace all wires in a block.
11288
11389
:param {old_wire: new_wire} wire_map: mapping of old wires to new wires
11490
:param block: block to operate over (defaults to working block)
@@ -120,11 +96,7 @@ def replace_wires(wire_map, block=None):
12096

12197

12298
def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=None):
123-
"""
124-
Replace orig_wire with new_src and/or new_dst. The net that orig_wire originates from
125-
(its source net) will now feed into new_src as its destination, and the nets that
126-
orig_wire went to (its destination nets) will be fed from new_dst as
127-
their respective arguments.
99+
""" Replace orig_wire with new_src and/or new_dst.
128100
129101
:param WireVector orig_wire: Wire to be replaced
130102
:param WireVector new_src: Wire to replace orig_wire, anywhere orig_wire is the
@@ -135,10 +107,12 @@ def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=Non
135107
:param {WireVector: List[LogicNet]} dst_nets: Maps a wire to list of nets where it is an arg
136108
:param Block block: The block on which to operate (defaults to working block)
137109
138-
new_src will now originate from orig_wire's source net (meaning new_src will be that net's
139-
destination). new_dst will be now
110+
The net that orig_wire originates from (its source net) will use new_src as its
111+
destination wire. The nets that orig_wire went to (its destination nets) will now
112+
have new_dst as one of their argument wires instead.
140113
141-
This *updates* the src_nets and dst_nets maps that are passed in, such that:
114+
This removes and/or adds nets to the block's logic set. This also *updates* the
115+
src_nets and dst_nets maps that are passed in, such that the following hold:
142116
143117
```
144118
old_src_net = src_nets[orig_wire]
@@ -150,8 +124,6 @@ def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=Non
150124
dst_nets[new_dst] = [old_dst_net (where old_dst_net.args replaces orig_wire with new_dst) foreach old_dst_net] # noqa
151125
```
152126
153-
This also removes and/or adds nets to the block's logic set.
154-
155127
For example, given the graph on left, `replace_wire_fast(w1, w4, w1, ...)` produces on right:
156128
157129
```

tests/test_transform.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,56 @@ def test_randomly_replace(self):
7575
self.assertIsNot(arg, b)
7676
self.assertIsNot(new_and_net.dests[0], o)
7777

78+
def test_replace_input(self):
79+
80+
def f(wire):
81+
if wire.name == 'a':
82+
w = pyrtl.clone_wire(wire, 'w2')
83+
else:
84+
w = pyrtl.clone_wire(wire, 'w3')
85+
return wire, w
86+
87+
a, b = pyrtl.input_list('a/1 b/1')
88+
w1 = a & b
89+
o = pyrtl.Output(1, 'o')
90+
o <<= w1
91+
92+
src_nets, dst_nets = pyrtl.working_block().net_connections()
93+
self.assertEqual(src_nets[w1], pyrtl.LogicNet('&', None, (a, b), (w1,)))
94+
self.assertIn(a, dst_nets)
95+
self.assertIn(b, dst_nets)
96+
97+
transform.wire_transform(f, select_types=pyrtl.Input, exclude_types=tuple())
98+
99+
w2 = pyrtl.working_block().get_wirevector_by_name('w2')
100+
w3 = pyrtl.working_block().get_wirevector_by_name('w3')
101+
src_nets, dst_nets = pyrtl.working_block().net_connections()
102+
self.assertEqual(src_nets[w1], pyrtl.LogicNet('&', None, (w2, w3), (w1,)))
103+
self.assertNotIn(a, dst_nets)
104+
self.assertNotIn(b, dst_nets)
105+
106+
def test_replace_output(self):
107+
108+
def f(wire):
109+
w = pyrtl.clone_wire(wire, 'w2')
110+
return w, wire
111+
112+
a, b = pyrtl.input_list('a/1 b/1')
113+
w1 = a & b
114+
o = pyrtl.Output(1, 'o')
115+
o <<= w1
116+
117+
src_nets, dst_nets = pyrtl.working_block().net_connections()
118+
self.assertEqual(dst_nets[w1], [pyrtl.LogicNet('w', None, (w1,), (o,))])
119+
self.assertIn(o, src_nets)
120+
121+
transform.wire_transform(f, select_types=pyrtl.Output, exclude_types=tuple())
122+
123+
w2 = pyrtl.working_block().get_wirevector_by_name('w2')
124+
src_nets, dst_nets = pyrtl.working_block().net_connections()
125+
self.assertEqual(dst_nets[w1], [pyrtl.LogicNet('w', None, (w1,), (w2,))])
126+
self.assertNotIn(o, src_nets)
127+
78128

79129
class TestCopyBlock(NetWireNumTestCases, WireMemoryNameTestCases):
80130
def num_memories(self, mems_expected, block):

0 commit comments

Comments
 (0)