Skip to content

Commit 4ea798f

Browse files
committed
Add sanity checks for wirevector_by_name
This revealed a few bugs in passes where the wirevector_by_name dictionary was not being updated. Also updated code’s developer notes for clarity.
1 parent 4b17076 commit 4ea798f

File tree

3 files changed

+49
-11
lines changed

3 files changed

+49
-11
lines changed

pyrtl/core.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def add_wirevector(self, wirevector):
279279
self.wirevector_by_name[wirevector.name] = wirevector
280280

281281
def remove_wirevector(self, wirevector):
282-
""" Remove a wirevector object to the block."""
282+
""" Remove a wirevector object from the block."""
283283
self.wirevector_set.remove(wirevector)
284284
del self.wirevector_by_name[wirevector.name]
285285

@@ -505,7 +505,6 @@ def sanity_check(self):
505505
built according to the assumptions stated in the Block comments.
506506
"""
507507

508-
# TODO: check that the wirevector_by_name is sane
509508
from .wire import Input, Const, Output
510509
from .helperfuncs import get_stack, get_stacks
511510

@@ -529,9 +528,6 @@ def sanity_check(self):
529528
'or "const_" as a signal name because those are reserved for '
530529
'internal use)' % repr(wirevector_names_list))
531530

532-
# check for dead input wires (not connected to anything)
533-
all_input_and_consts = self.wirevector_subset((Input, Const))
534-
535531
# The following line also checks for duplicate wire drivers
536532
wire_src_dict, wire_dst_dict = self.net_connections()
537533
dest_set = set(wire_src_dict.keys())
@@ -542,9 +538,12 @@ def sanity_check(self):
542538
bad_wire_names = '\n '.join(str(x) for x in connected_minus_allwires)
543539
raise PyrtlError('Unknown wires found in net:\n %s \n\n %s' % (bad_wire_names,
544540
get_stacks(*connected_minus_allwires)))
541+
542+
all_input_and_consts = self.wirevector_subset((Input, Const))
543+
544+
# Check for wires that aren't connected to anything (inputs and consts can be unconnected)
545545
allwires_minus_connected = self.wirevector_set.difference(full_set)
546546
allwires_minus_connected = allwires_minus_connected.difference(all_input_and_consts)
547-
# ^ allow inputs and consts to be unconnected
548547
if len(allwires_minus_connected) > 0:
549548
bad_wire_names = '\n '.join(str(x) for x in allwires_minus_connected)
550549
raise PyrtlError('Wires declared but not connected:\n %s \n\n %s' % (bad_wire_names,
@@ -561,6 +560,24 @@ def sanity_check(self):
561560
# Check for async memories not specified as such
562561
self.sanity_check_memory_sync(wire_src_dict)
563562

563+
# Check that all mappings in wirevector_by_name are consistent
564+
bad_wv_by_name = [w for n, w in self.wirevector_by_name.items() if n != w.name]
565+
if bad_wv_by_name:
566+
raise PyrtlInternalError('Wires with inconsistent entry in wirevector_by_name '
567+
'dict: %s' % [w.name for w in bad_wv_by_name])
568+
569+
# Check that all wires are in wirevector_by_name
570+
wv_by_name_set = set(self.wirevector_by_name.keys())
571+
missing_wires = wirevector_names_set.difference(wv_by_name_set)
572+
if missing_wires:
573+
raise PyrtlInternalError('Missing entries in wirevector_by_name for the '
574+
'following wires: %s' % missing_wires)
575+
576+
unknown_wires = wv_by_name_set.difference(wirevector_names_set)
577+
if unknown_wires:
578+
raise PyrtlInternalError('Unknown wires found in wirevector_by_name: %s'
579+
% unknown_wires)
580+
564581
if debug_mode:
565582
# Check for wires that are destinations of a logicNet, but are not outputs and are never
566583
# used as args.

pyrtl/passes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ def _remove_wire_nets(block):
9898
# now update the block with the new logic and remove wirevectors
9999
block.logic = new_logic
100100
for dead_wirevector in wire_removal_set:
101-
del block.wirevector_by_name[dead_wirevector.name]
102-
block.wirevector_set.remove(dead_wirevector)
101+
block.remove_wirevector(dead_wirevector)
103102

104103
block.sanity_check()
105104

@@ -446,6 +445,7 @@ def _remove_unused_wires(block, keep_inputs=True):
446445
PyrtlInternalError("Output wire, " + removed_wire.name + " not driven")
447446

448447
block.wirevector_set = valid_wires
448+
block.wirevector_by_name = {wire.name: wire for wire in valid_wires}
449449

450450
# --------------------------------------------------------------------
451451
# __ ___ ___ __ __
@@ -825,7 +825,8 @@ def direct_connect_outputs(block=None):
825825

826826
block.logic.difference_update(nets_to_remove)
827827
block.logic.update(nets_to_add)
828-
block.wirevector_set.difference_update(wirevectors_to_remove)
828+
for w in wirevectors_to_remove:
829+
block.remove_wirevector(w)
829830

830831

831832
def _make_tree(wire, block, curr_fanout):

tests/test_core.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,8 @@ class TestSanityCheck(unittest.TestCase):
476476
def setUp(self):
477477
pyrtl.reset_working_block()
478478

479-
def sanity_error(self, msg):
480-
with self.assertRaisesRegexp(pyrtl.PyrtlError, msg):
479+
def sanity_error(self, msg, error_type=pyrtl.PyrtlError):
480+
with self.assertRaisesRegexp(error_type, msg):
481481
pyrtl.working_block().sanity_check()
482482

483483
def test_missing_bitwidth(self):
@@ -510,6 +510,26 @@ def test_not_driven(self):
510510
out <<= w
511511
self.sanity_error("used but never driven")
512512

513+
def test_inconsistent_wirevector_by_name(self):
514+
c = pyrtl.Const(42)
515+
inp = pyrtl.Input(8, 'inp')
516+
out = pyrtl.Output(8, 'out')
517+
out <<= inp & c
518+
pyrtl.working_block().wirevector_by_name['inp'] = c
519+
self.sanity_error("inconsistent entry in wirevector_by_name", pyrtl.PyrtlInternalError)
520+
521+
def test_missing_wire_in_wirevector_by_name(self):
522+
inp = pyrtl.Input(8, 'inp')
523+
out = pyrtl.Output(8, 'out')
524+
out <<= inp
525+
del pyrtl.working_block().wirevector_by_name['inp']
526+
self.sanity_error("Missing entries in wirevector_by_name", pyrtl.PyrtlInternalError)
527+
528+
def test_extra_wire_in_wirevector_by_name(self):
529+
inp = pyrtl.Input(8, 'inp')
530+
pyrtl.working_block().wirevector_set.discard(inp)
531+
self.sanity_error("Unknown wires found in wirevector_by_name", pyrtl.PyrtlInternalError)
532+
513533

514534
class TestLogicNets(unittest.TestCase):
515535
def setUp(self):

0 commit comments

Comments
 (0)