Skip to content

Commit afd91a1

Browse files
committed
remove replace_wire, use replace_wire_fast instead; add test and more documentation
1 parent e00bb40 commit afd91a1

File tree

3 files changed

+69
-42
lines changed

3 files changed

+69
-42
lines changed

pyrtl/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,9 @@
110110
from .passes import one_bit_selects
111111
from .passes import two_way_concat
112112

113-
from .transform import net_transform, wire_transform, replace_wire, copy_block, clone_wire
113+
from .transform import net_transform
114+
from .transform import wire_transform
115+
from .transform import copy_block
116+
from .transform import clone_wire
117+
from .transform import replace_wires
118+
from .transform import replace_wire_fast

pyrtl/transform.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -63,51 +63,27 @@ def wire_transform(transform_func, select_types=WireVector,
6363
:param select_types: Type or Tuple of types of WireVectors to replace
6464
:param exclude_types: Type or Tuple of types of WireVectors to exclude from replacement
6565
:param block: The Block to replace wires on
66+
67+
Note that if both new_src and new_dst don't equal orig_wire, orig_wire will
68+
be removed from the block entirely.
6669
"""
6770
block = working_block(block)
71+
src_nets, dst_nets = block.net_connections(include_virtual_nodes=False)
6872
for orig_wire in block.wirevector_subset(select_types, exclude_types):
6973
new_src, new_dst = transform_func(orig_wire)
70-
replace_wire(orig_wire, new_src, new_dst, block)
74+
replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block)
7175

7276

7377
def all_wires(transform_func):
74-
"""Decorator that wraps a wire transform function"""
78+
""" Decorator that wraps a wire transform function. """
7579
@functools.wraps(transform_func)
7680
def t_res(**kwargs):
7781
wire_transform(transform_func, **kwargs)
7882
return t_res
7983

8084

81-
def replace_wire(orig_wire, new_src, new_dst, block=None):
82-
block = working_block(block)
83-
if new_src is not orig_wire:
84-
# don't need to add the new_src and new_dst because they were made added at creation
85-
for net in block.logic:
86-
for wire in net.dests: # problem is that tuples use the == operator when using 'in'
87-
if wire is orig_wire:
88-
new_net = LogicNet(
89-
op=net.op, op_param=net.op_param, args=net.args,
90-
dests=tuple(new_src if w is orig_wire else w for w in net.dests))
91-
block.add_net(new_net)
92-
block.logic.remove(net)
93-
break
94-
95-
if new_dst is not orig_wire:
96-
for net in block.logic:
97-
for wire in set(net.args):
98-
if wire is orig_wire:
99-
new_net = LogicNet(
100-
op=net.op, op_param=net.op_param, dests=net.dests,
101-
args=tuple(new_src if w is orig_wire else w for w in net.args))
102-
block.add_net(new_net)
103-
block.logic.remove(net)
104-
105-
if new_dst is not orig_wire and new_src is not orig_wire:
106-
block.remove_wirevector(orig_wire)
107-
108-
10985
def replace_wires(wire_map, block=None):
110-
""" Quickly replace all wires in a block.
86+
""" Replace all wires in a block.
11187
11288
:param {old_wire: new_wire} wire_map: mapping of old wires to new wires
11389
:param block: block to operate over (defaults to working block)
@@ -119,11 +95,7 @@ def replace_wires(wire_map, block=None):
11995

12096

12197
def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=None):
122-
"""
123-
Replace orig_wire with new_src and/or new_dst. The net that orig_wire originates from
124-
(its source net) will now feed into new_src as its destination, and the nets that
125-
orig_wire went to (its destination nets) will be fed from new_dst as
126-
their respective arguments.
98+
""" Replace orig_wire with new_src and/or new_dst.
12799
128100
:param WireVector orig_wire: Wire to be replaced
129101
:param WireVector new_src: Wire to replace orig_wire, anywhere orig_wire is the
@@ -134,10 +106,12 @@ def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=Non
134106
:param {WireVector: List[LogicNet]} dst_nets: Maps a wire to list of nets where it is an arg
135107
:param Block block: The block on which to operate (defaults to working block)
136108
137-
new_src will now originate from orig_wire's source net (meaning new_src will be that net's
138-
destination). new_dst will be now
109+
The net that orig_wire originates from (its source net) will use new_src as its
110+
destination wire. The nets that orig_wire went to (its destination nets) will now
111+
have new_dst as one of their argument wires instead.
139112
140-
This *updates* the src_nets and dst_nets maps that are passed in, such that:
113+
This removes and/or adds nets to the block's logic set. This also *updates* the
114+
src_nets and dst_nets maps that are passed in, such that the following hold:
141115
142116
```
143117
old_src_net = src_nets[orig_wire]
@@ -149,8 +123,6 @@ def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=Non
149123
dst_nets[new_dst] = [old_dst_net (where old_dst_net.args replaces orig_wire with new_dst) foreach old_dst_net] # noqa
150124
```
151125
152-
This also removes and/or adds nets to the block's logic set.
153-
154126
For example, given the graph on left, `replace_wire_fast(w1, w4, w1, ...)` produces on right:
155127
156128
```

tests/test_transform.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,56 @@ def test_randomly_replace(self):
7373
self.assertIsNot(arg, b)
7474
self.assertIsNot(new_and_net.dests[0], o)
7575

76+
def test_replace_input(self):
77+
78+
def f(wire):
79+
if wire.name == 'a':
80+
w = pyrtl.clone_wire(wire, 'w2')
81+
else:
82+
w = pyrtl.clone_wire(wire, 'w3')
83+
return wire, w
84+
85+
a, b = pyrtl.input_list('a/1 b/1')
86+
w1 = a & b
87+
o = pyrtl.Output(1, 'o')
88+
o <<= w1
89+
90+
src_nets, dst_nets = pyrtl.working_block().net_connections()
91+
self.assertEqual(src_nets[w1], pyrtl.LogicNet('&', None, (a, b), (w1,)))
92+
self.assertIn(a, dst_nets)
93+
self.assertIn(b, dst_nets)
94+
95+
transform.wire_transform(f, select_types=pyrtl.Input, exclude_types=tuple())
96+
97+
w2 = pyrtl.working_block().get_wirevector_by_name('w2')
98+
w3 = pyrtl.working_block().get_wirevector_by_name('w3')
99+
src_nets, dst_nets = pyrtl.working_block().net_connections()
100+
self.assertEqual(src_nets[w1], pyrtl.LogicNet('&', None, (w2, w3), (w1,)))
101+
self.assertNotIn(a, dst_nets)
102+
self.assertNotIn(b, dst_nets)
103+
104+
def test_replace_output(self):
105+
106+
def f(wire):
107+
w = pyrtl.clone_wire(wire, 'w2')
108+
return w, wire
109+
110+
a, b = pyrtl.input_list('a/1 b/1')
111+
w1 = a & b
112+
o = pyrtl.Output(1, 'o')
113+
o <<= w1
114+
115+
src_nets, dst_nets = pyrtl.working_block().net_connections()
116+
self.assertEqual(dst_nets[w1], [pyrtl.LogicNet('w', None, (w1,), (o,))])
117+
self.assertIn(o, src_nets)
118+
119+
transform.wire_transform(f, select_types=pyrtl.Output, exclude_types=tuple())
120+
121+
w2 = pyrtl.working_block().get_wirevector_by_name('w2')
122+
src_nets, dst_nets = pyrtl.working_block().net_connections()
123+
self.assertEqual(dst_nets[w1], [pyrtl.LogicNet('w', None, (w1,), (w2,))])
124+
self.assertNotIn(o, src_nets)
125+
76126

77127
class TestCopyBlock(NetWireNumTestCases, WireMemoryNameTestCases):
78128
def num_memories(self, mems_expected, block):

0 commit comments

Comments
 (0)