Skip to content

Commit 91906d2

Browse files
committed
Updating clone_wire to be safe when used on same block.
1 parent e00bb40 commit 91906d2

File tree

2 files changed

+74
-8
lines changed

2 files changed

+74
-8
lines changed

pyrtl/transform.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import functools
2222

2323
from .core import set_working_block, LogicNet, working_block
24-
from .wire import Const, Input, Output, WireVector, Register
24+
from .wire import Const, Input, Output, WireVector, Register, next_tempvar_name
2525

2626

2727
def net_transform(transform_func, block=None, **kwargs):
@@ -242,17 +242,18 @@ def clone_wire(old_wire, name=None):
242242
:param old_wire: The wire to clone
243243
:param name: A name for the new wire
244244
245-
Note that this function is mainly intended to be used when the
246-
two wires are from different blocks. Making two wires with the
247-
same name in the same block is not allowed.
245+
Naming the newly cloned wire the same as another existing wire in the same
246+
block will cause the other wire to be given a new internally created name. This
247+
function is mainly intended to be used when the two wires are from different blocks.
248248
"""
249+
name = old_wire.name if name is None else name
250+
if name in working_block().wirevector_by_name and (working_block() is old_wire._block):
251+
w = working_block().wirevector_by_name[name]
252+
w.name = next_tempvar_name()
253+
249254
if isinstance(old_wire, Const):
250-
if name is None:
251-
return Const(old_wire.val, old_wire.bitwidth, name=old_wire.name)
252255
return Const(old_wire.val, old_wire.bitwidth, name=name)
253256
else:
254-
if name is None:
255-
return old_wire.__class__(old_wire.bitwidth, name=old_wire.name)
256257
return old_wire.__class__(old_wire.bitwidth, name=name)
257258

258259

tests/test_transform.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pyrtl.core import set_working_block
12
import unittest
23
import pyrtl
34
from pyrtl import transform
@@ -257,6 +258,70 @@ def test_replace_only_dst_wire(self):
257258
self.assertEqual(w4_dst_net.args, (w4, w2))
258259
self.assertEqual(w4_dst_net.dests, w1_dst_net.dests)
259260

261+
class TestCloning(unittest.TestCase):
262+
def setUp(self):
263+
pyrtl.reset_working_block()
264+
265+
def test_clone_wire_no_or_same_name_same_block(self):
266+
for clone_name in (None, 'a'):
267+
a = pyrtl.WireVector(1, 'a')
268+
self.assertEqual(a.name, 'a')
269+
self.assertEqual(pyrtl.working_block().wirevector_set, {a})
270+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
271+
272+
w = pyrtl.clone_wire(a, name=clone_name)
273+
self.assertTrue(a.name.startswith("tmp"))
274+
self.assertEqual(w.name, 'a')
275+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], w)
276+
self.assertEqual(pyrtl.working_block().wirevector_set, {a, w})
277+
278+
w.name = 'w'
279+
self.assertEqual(w.name, 'w')
280+
self.assertIs(pyrtl.working_block().wirevector_by_name['w'], w)
281+
282+
pyrtl.working_block().remove_wirevector(a)
283+
self.assertEqual(pyrtl.working_block().wirevector_set, {w})
284+
pyrtl.reset_working_block()
285+
286+
def test_clone_wire_no_or_same_name_different_block(self):
287+
for clone_name in (None, 'a'):
288+
a = pyrtl.WireVector(1, 'a')
289+
b = pyrtl.Block()
290+
with pyrtl.set_working_block(b):
291+
w = pyrtl.clone_wire(a, name=clone_name)
292+
293+
self.assertEqual(a.name, 'a')
294+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
295+
self.assertEqual(pyrtl.working_block().wirevector_set, {a})
296+
297+
self.assertEqual(w.name, 'a')
298+
self.assertIs(b.wirevector_by_name['a'], w)
299+
self.assertEqual(b.wirevector_set, {w})
300+
pyrtl.reset_working_block()
301+
302+
def test_clone_wire_with_different_name_same_block(self):
303+
a = pyrtl.WireVector(1, 'a')
304+
w = pyrtl.clone_wire(a, 'w')
305+
self.assertEqual(a.name, 'a')
306+
self.assertEqual(w.name, 'w')
307+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
308+
self.assertEqual(pyrtl.working_block().wirevector_set, {a, w})
309+
310+
pyrtl.working_block().remove_wirevector(a)
311+
self.assertEqual(pyrtl.working_block().wirevector_set, {w})
312+
313+
def test_clone_wire_with_different_name_different_block(self):
314+
a = pyrtl.WireVector(1, 'a')
315+
b = pyrtl.Block()
316+
with set_working_block(b):
317+
w = pyrtl.clone_wire(a, 'w')
318+
self.assertEqual(a.name, 'a')
319+
self.assertEqual(w.name, 'w')
320+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
321+
self.assertEqual(pyrtl.working_block().wirevector_set, {a})
322+
self.assertIs(b.wirevector_by_name['w'], w)
323+
self.assertEqual(b.wirevector_set, {w})
324+
260325

261326
# this code needs mocking from python 3's unittests to work
262327
"""

0 commit comments

Comments
 (0)