Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions pyrtl/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,84 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False):
constant_propagation(block, True)
_remove_unlistened_nets(block)
common_subexp_elimination(block)
_remove_double_inverts(block, skip_sanity_check)
if (not skip_sanity_check) or _get_debug_mode():
block.sanity_check()
return block


def _remove_double_inverts(block, skip_sanity_check=False):
""" Removes all double invert nets from the block. """

# checks if the wirevector is used at a LogicNet other than used_at_nets
def is_wirevector_used_elsewhere(wire, used_at_nets):
for net in block.logic:
if net not in used_at_nets:
if wire.name in [x.name for x in net.args] \
or wire.name in [x.name for x in net.dests]:
return True
return False

new_logic = set()
net_exclude_set = set() # removed nets
wire_removal_set = set()
# Dictionary, key is the destination wire of the invert net, value is the invert net
invert_destination_wires = {}
for net in block.logic:
if net.op == "~":
invert_destination_wires[net.dests[0].name] = net
# If double invert nets are removed randomly, this may leave some double inverts behind.
# Example: ~(~(~(~a)))
# If we remove the middle two inverts first, we will end up with ~((~a)). These remaining
# double inverts won't get removed because they aren't directly connected.
# To avoid this, we remove double inverts in a chain sequentially from start to end.
# For example, we first remove the two outer inverts from ~(~(~(~a))) to get ~(~a),
# and then remove the remaining two inner inverts. To do this, we need iterate through
# the invert_destination_wires dictionary multiple times, hence the outer while loop.
repeat = True
while repeat:
repeat = False
removed_nets = set()
for net in invert_destination_wires.values():
# If the argument of the net is in invert_destination_wires, then it is a double invert
# If the net is in net_exclude_set, then it was already removed, so we do not process it
if net.args[0].name in invert_destination_wires and net not in net_exclude_set:
previous_net = invert_destination_wires[net.args[0].name]
if not is_wirevector_used_elsewhere(net.args[0], (net, previous_net)) \
and previous_net not in net_exclude_set:
# If previous_net is in invert_destination_wires, we have a chain of
# 3 or more double inverts. To make sure we remove double inverts
# in these chains sequentially, we only remove the double invert
# we found if the invert net whose destination is previous_net
# was not removed yet. If it was not yet removed, the for loop
# needs to run again, so we set repeat to True.
if previous_net.args[0].name in invert_destination_wires:
repeat = True
else:
new_logic.add(LogicNet('w', None, args=previous_net.args, dests=net.dests))
wire_removal_set.add(net.args[0])
removed_nets.add(net)
removed_nets.add(previous_net)
# remove removed_nets from invert_destination_wires to optimize the for loop
for net in removed_nets:
del invert_destination_wires[net.dests[0].name]
net_exclude_set.update(removed_nets)

for net in block.logic:
if net not in net_exclude_set:
new_logic.add(net)

block.logic = new_logic
for dead_wirevector in wire_removal_set:
block.remove_wirevector(dead_wirevector)

if (not skip_sanity_check) or _get_debug_mode():
block.sanity_check()

# clean up wire nodes
_remove_wire_nets(block, skip_sanity_check)


class _ProducerList(object):
""" Maps from wire to its immediate producer and finds ultimate producers. """
def __init__(self):
Expand Down
81 changes: 81 additions & 0 deletions tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,87 @@ def test_slice_net_removal_4(self):
self.num_net_of_type('s', 1, block)
self.num_net_of_type('w', 2, block)

def test_remove_double_inverts_1_invert(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~inwire
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_3_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~inwire))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_5_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~(~inwire))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_2_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~inwire)
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_remove_double_inverts_4_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~inwire)))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_remove_double_inverts_6_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~(~(~inwire)))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_dont_remove_double_inverts_another_user(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire2 = pyrtl.Output(bitwidth=1)
tempwire = pyrtl.WireVector()
tempwire <<= ~inwire
outwire <<= ~tempwire
outwire2 <<= tempwire
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(4, block)
self.assert_num_wires(5, block)

def test_multiple_double_invert_chains(self):
# _remove_double_inverts removes double inverts by chains,
# so it is useful to make sure it can remove
# double inverts from multiple chains
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire2 = pyrtl.Output(bitwidth=1)
outwire <<= ~(~inwire)
outwire2 <<= ~(~(~(~(inwire))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)


class TestConstFolding(NetWireNumTestCases):
def setUp(self):
Expand Down
Loading