Skip to content

Commit 3595dda

Browse files
authored
Merge pull request #349 from pllab/clone_update
Updating `clone_wire` to be safe when used on same block.
2 parents 4e82a69 + 13108c6 commit 3595dda

File tree

2 files changed

+93
-9
lines changed

2 files changed

+93
-9
lines changed

pyrtl/transform.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
structures (through Block.logic, block.Wirevector_set, etc).
2020
"""
2121
import functools
22+
from pyrtl.pyrtlexceptions import PyrtlError
2223

2324
from .core import set_working_block, LogicNet, working_block
2425
from .wire import Const, Input, Output, WireVector, Register
@@ -240,19 +241,25 @@ def clone_wire(old_wire, name=None):
240241
""" Makes a copy of any existing wire.
241242
242243
:param old_wire: The wire to clone
243-
:param name: A name for the new wire
244+
:param name: A name for the new wire (required if the old wire
245+
and newly cloned wire are part of the same block)
244246
245-
Note that this function is mainly intended to be used when the
246-
two wires are from different blocks. Making two wires with the
247-
same name in the same block is not allowed.
247+
This function is mainly intended to be used when the two wires are from different
248+
blocks. Making two wires with the same name in the same block is not allowed.
248249
"""
250+
if name is None:
251+
if working_block() is old_wire._block:
252+
raise PyrtlError("Must provide a name for the newly cloned wire "
253+
"when cloning within the same block.")
254+
name = old_wire.name
255+
256+
if name in working_block().wirevector_by_name:
257+
raise PyrtlError("Cannot give a newly cloned wire the same name "
258+
"as an existing wire.")
259+
249260
if isinstance(old_wire, Const):
250-
if name is None:
251-
return Const(old_wire.val, old_wire.bitwidth, name=old_wire.name)
252261
return Const(old_wire.val, old_wire.bitwidth, name=name)
253262
else:
254-
if name is None:
255-
return old_wire.__class__(old_wire.bitwidth, name=old_wire.name)
256263
return old_wire.__class__(old_wire.bitwidth, name=name)
257264

258265

tests/test_transform.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pyrtl.core import set_working_block
12
import unittest
23
import pyrtl
34
from pyrtl import transform
@@ -50,7 +51,8 @@ def insert_random_inversions(rate=0.5):
5051

5152
def randomly_replace(wire):
5253
if random.random() < rate:
53-
new_src, new_dst = transform.clone_wire(wire), transform.clone_wire(wire)
54+
new_src = transform.clone_wire(wire, pyrtl.wire.next_tempvar_name())
55+
new_dst = transform.clone_wire(wire, pyrtl.wire.next_tempvar_name())
5456
new_dst <<= ~new_src
5557
return new_src, new_dst
5658
return wire, wire
@@ -258,6 +260,81 @@ def test_replace_only_dst_wire(self):
258260
self.assertEqual(w4_dst_net.dests, w1_dst_net.dests)
259261

260262

263+
class TestCloning(unittest.TestCase):
264+
def setUp(self):
265+
pyrtl.reset_working_block()
266+
267+
def test_same_type(self):
268+
for ix, cls in enumerate([pyrtl.WireVector, pyrtl.Register, pyrtl.Input, pyrtl.Output]):
269+
w1 = cls(4, 'w%d' % ix)
270+
w2 = pyrtl.clone_wire(w1, 'y%d' % ix)
271+
self.assertIsInstance(w2, cls)
272+
self.assertEqual(w1.bitwidth, w2.bitwidth)
273+
274+
def test_clone_wire_no_name_same_block(self):
275+
a = pyrtl.WireVector(1, 'a')
276+
with self.assertRaises(pyrtl.PyrtlError) as error:
277+
pyrtl.clone_wire(a)
278+
self.assertEqual(
279+
str(error.exception),
280+
"Must provide a name for the newly cloned wire "
281+
"when cloning within the same block."
282+
)
283+
284+
def test_clone_wire_same_name_same_block(self):
285+
a = pyrtl.WireVector(1, 'a')
286+
with self.assertRaises(pyrtl.PyrtlError) as error:
287+
pyrtl.clone_wire(a, 'a')
288+
self.assertEqual(
289+
str(error.exception),
290+
"Cannot give a newly cloned wire the same name as an existing wire."
291+
)
292+
293+
def test_clone_wire_different_name_same_block(self):
294+
a = pyrtl.WireVector(1, 'a')
295+
self.assertEqual(a.name, 'a')
296+
self.assertEqual(pyrtl.working_block().wirevector_set, {a})
297+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
298+
299+
w = pyrtl.clone_wire(a, name='w')
300+
self.assertEqual(w.name, 'w')
301+
self.assertEqual(a.name, 'a')
302+
self.assertIs(pyrtl.working_block().wirevector_by_name['w'], w)
303+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
304+
self.assertEqual(pyrtl.working_block().wirevector_set, {a, w})
305+
306+
pyrtl.working_block().remove_wirevector(a)
307+
self.assertEqual(pyrtl.working_block().wirevector_set, {w})
308+
309+
def test_clone_wire_no_or_same_name_different_block(self):
310+
for clone_name in (None, 'a'):
311+
a = pyrtl.WireVector(1, 'a')
312+
b = pyrtl.Block()
313+
with pyrtl.set_working_block(b):
314+
w = pyrtl.clone_wire(a, name=clone_name)
315+
316+
self.assertEqual(a.name, 'a')
317+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
318+
self.assertEqual(pyrtl.working_block().wirevector_set, {a})
319+
320+
self.assertEqual(w.name, 'a')
321+
self.assertIs(b.wirevector_by_name['a'], w)
322+
self.assertEqual(b.wirevector_set, {w})
323+
pyrtl.reset_working_block()
324+
325+
def test_clone_wire_different_name_different_block(self):
326+
a = pyrtl.WireVector(1, 'a')
327+
b = pyrtl.Block()
328+
with set_working_block(b):
329+
w = pyrtl.clone_wire(a, 'w')
330+
self.assertEqual(a.name, 'a')
331+
self.assertEqual(w.name, 'w')
332+
self.assertIs(pyrtl.working_block().wirevector_by_name['a'], a)
333+
self.assertEqual(pyrtl.working_block().wirevector_set, {a})
334+
self.assertIs(b.wirevector_by_name['w'], w)
335+
self.assertEqual(b.wirevector_set, {w})
336+
337+
261338
# this code needs mocking from python 3's unittests to work
262339
"""
263340
@mock.patch('transform_examples.pyrtl.probe')

0 commit comments

Comments
 (0)