Skip to content

Commit a51f793

Browse files
committed
Add optimization pass that removes unneeded slices
1 parent d931e40 commit a51f793

File tree

2 files changed

+111
-7
lines changed

2 files changed

+111
-7
lines changed

pyrtl/passes.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False):
4646
if (not skip_sanity_check) or _get_debug_mode():
4747
block.sanity_check()
4848
_remove_wire_nets(block)
49+
_remove_slice_nets(block)
4950
constant_propagation(block, True)
5051
_remove_unlistened_nets(block)
5152
common_subexp_elimination(block)
@@ -103,6 +104,67 @@ def _remove_wire_nets(block):
103104
block.sanity_check()
104105

105106

107+
def _remove_slice_nets(block):
108+
""" Remove all unneeded slice nodes from the block.
109+
110+
Unneeded here means that the source and destination wires of a slice net are exactly
111+
the same, because the slice takes all the bits, in order, from the source.
112+
"""
113+
# Turns a net of form on the left into the one on the right:
114+
#
115+
# w1
116+
# |
117+
# [3:0]
118+
# |
119+
# [3:0] ===> w1
120+
# | |
121+
# [3:0] w2 [3:0] w2
122+
# / \ / / \ /
123+
# ~ + ~ +
124+
# | | | |
125+
126+
wire_src_dict = _ProducerList()
127+
wire_removal_set = set() # set of all wirevectors to be removed
128+
129+
def is_net_slicing_entire_wire(net):
130+
if net.op != 's':
131+
return False
132+
133+
src_wire = net.args[0]
134+
dst_wire = net.dests[0]
135+
if len(src_wire) != len(dst_wire):
136+
return False
137+
138+
selLower = net.op_param[0]
139+
selUpper = net.op_param[-1]
140+
# Check if getting all bits from the src_wire (i.e. consecutive bits, MSB to LSB)
141+
return net.op_param == tuple(range(selLower, selUpper + 1))
142+
143+
# one pass to build the map of value producers and
144+
# all of the nets and wires to be removed
145+
for net in block.logic:
146+
if is_net_slicing_entire_wire(net):
147+
wire_src_dict[net.dests[0]] = net.args[0]
148+
if not isinstance(net.dests[0], Output):
149+
wire_removal_set.add(net.dests[0])
150+
151+
# second full pass to create the new logic without the wire nets
152+
new_logic = set()
153+
for net in block.logic:
154+
if not is_net_slicing_entire_wire(net) or isinstance(net.dests[0], Output):
155+
new_args = tuple(wire_src_dict.find_producer(x) for x in net.args)
156+
new_net = LogicNet(net.op, net.op_param, new_args, net.dests)
157+
new_logic.add(new_net)
158+
159+
# now update the block with the new logic and remove wirevectors
160+
block.logic = new_logic
161+
for dead_wirevector in wire_removal_set:
162+
del block.wirevector_by_name[dead_wirevector.name]
163+
block.wirevector_set.remove(dead_wirevector)
164+
165+
block.sanity_check()
166+
167+
106168
def constant_propagation(block, silence_unexpected_net_warnings=False):
107169
""" Removes excess constants in the block.
108170

tests/test_passes.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ def test_wire_net_removal_1(self):
253253
outwire = pyrtl.Output()
254254
tempwire <<= inwire
255255
outwire <<= tempwire
256-
257256
pyrtl.synthesize()
258257
pyrtl.optimize()
259258
block = pyrtl.working_block()
@@ -276,6 +275,49 @@ def test_wire_net_removal_2(self):
276275
self.assert_num_net(5, block)
277276
self.assert_num_wires(6, block)
278277

278+
def test_slice_net_removal_1(self):
279+
constwire = pyrtl.Const(1, 1)
280+
inwire = pyrtl.Input(bitwidth=1)
281+
outwire = pyrtl.Output()
282+
outwire <<= constwire ^ inwire
283+
pyrtl.optimize()
284+
block = pyrtl.working_block()
285+
self.num_net_of_type('s', 0, block)
286+
self.num_net_of_type('~', 1, block)
287+
288+
def test_slice_net_removal_2(self):
289+
inwire = pyrtl.Input(bitwidth=3)
290+
outwire = pyrtl.Output()
291+
tempwire = inwire[0:3]
292+
outwire <<= tempwire[0:3]
293+
pyrtl.optimize()
294+
block = pyrtl.working_block()
295+
self.num_net_of_type('s', 0, block)
296+
self.num_net_of_type('w', 1, block)
297+
298+
def test_slice_net_removal_3(self):
299+
inwire = pyrtl.Input(bitwidth=3)
300+
outwire = pyrtl.Output()
301+
tempwire = inwire[0:2]
302+
outwire <<= tempwire[0:2]
303+
pyrtl.optimize()
304+
# Removes one of the slices, which does nothing.
305+
block = pyrtl.working_block()
306+
self.num_net_of_type('s', 1, block)
307+
self.num_net_of_type('w', 1, block)
308+
309+
def test_slice_net_removal_4(self):
310+
inwire = pyrtl.Input(bitwidth=4)
311+
outwire1 = pyrtl.Output()
312+
outwire2 = pyrtl.Output()
313+
outwire1 <<= inwire[0:4]
314+
outwire2 <<= inwire[0:3]
315+
pyrtl.optimize()
316+
# Removes just the outwire1 slice, which does nothing.
317+
block = pyrtl.working_block()
318+
self.num_net_of_type('s', 1, block)
319+
self.num_net_of_type('w', 2, block)
320+
279321

280322
class TestConstFolding(NetWireNumTestCases):
281323
def setUp(self):
@@ -343,8 +385,8 @@ def test_adv_one_var_op_2(self):
343385
# Note: the current implementation still sticks a wire net between
344386
# a register 'nextsetter' wire and the output wire
345387
self.num_net_of_type('w', 1, block)
346-
self.assert_num_net(4, block)
347-
self.assert_num_wires(5, block)
388+
self.assert_num_net(3, block)
389+
self.assert_num_wires(4, block)
348390
self.num_wire_of_type(Const, 0, block)
349391
self.num_wire_of_type(Output, 1, block)
350392

@@ -412,9 +454,9 @@ def test_two_var_op_produce_not(self):
412454
block = pyrtl.working_block(None)
413455
self.num_net_of_type('~', 1, block)
414456
self.num_net_of_type('w', 1, block)
415-
self.num_net_of_type('s', 1, block) # due to synthesis
416-
self.assert_num_net(3, block)
417-
self.assert_num_wires(4, block)
457+
self.num_net_of_type('s', 0, block)
458+
self.assert_num_net(2, block)
459+
self.assert_num_wires(3, block)
418460
self.num_wire_of_type(Const, 0, block)
419461

420462
def test_two_var_op_correct_wire_prop(self):
@@ -703,7 +745,7 @@ def test_wirevector_1(self):
703745
outwire <<= ~tempwire2
704746
self.everything_t_procedure(48.5, 48.5)
705747
block = pyrtl.working_block()
706-
self.assert_num_net(3, block)
748+
self.assert_num_net(2, block)
707749

708750
def test_combo_1(self):
709751
inwire, inwire2 = pyrtl.Input(bitwidth=1), pyrtl.Input(bitwidth=1)

0 commit comments

Comments
 (0)