Skip to content

Commit ad2c54e

Browse files
authored
Merge pull request #345 from pllab/synthesis-update
Fix post-synth passes test; add reg_map for tracking registers post-synthesis
2 parents 2009957 + 8eb694d commit ad2c54e

File tree

4 files changed

+38
-8
lines changed

4 files changed

+38
-8
lines changed

pyrtl/core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,11 +730,21 @@ def sanity_check_net(self, net):
730730
class PostSynthBlock(Block):
731731
""" This is a block with extra metadata required to maintain the
732732
pre-synthesis interface during post-synthesis.
733+
734+
It currently holds the following instance attributes:
735+
736+
* *.io_map*: a map from old IO wirevector to a list of new IO wirevectors it maps to;
737+
this is a list because for unmerged io vectors, each old N-bit IO wirevector maps
738+
to N new 1-bit IO wirevectors.
739+
* *.reg_map*: a map from old register to a list of new registers; a list because post-synthesis,
740+
each N-bit register has been mapped to N 1-bit registers
741+
* *.mem_map*: a map from old memory block to the new memory block
733742
"""
734743

735744
def __init__(self):
736745
super(PostSynthBlock, self).__init__()
737-
self.io_map = {}
746+
self.io_map = collections.defaultdict(list)
747+
self.reg_map = collections.defaultdict(list)
738748
self.mem_map = {}
739749

740750

pyrtl/passes.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
1616
from .wire import WireVector, Input, Output, Const, Register
1717
from .transform import net_transform, _get_new_block_mem_instance, copy_block, replace_wires
18-
from . import transform # transform.all_nets looks better than all_nets
18+
from . import transform
19+
from pyrtl import wire # transform.all_nets looks better than all_nets
1920

2021

2122
# --------------------------------------------------------------------
@@ -411,6 +412,7 @@ def synthesize(update_working_block=True, merge_io_vectors=True, block=None):
411412
# to the original io wirevector found in block_pre. We use it to create
412413
# the block_out.io_map that is returned to the user.
413414
orig_io_map = {temp: orig for orig, temp in block_in.io_map.items()}
415+
orig_reg_map = {temp: orig for orig, temp in block_in.reg_map.items()}
414416

415417
# Next, create all of the new wires for the new block
416418
# from the original wires and store them in the wirevector_map
@@ -430,9 +432,11 @@ def synthesize(update_working_block=True, merge_io_vectors=True, block=None):
430432
if len(wirevector) > 1:
431433
new_name += '[' + str(i) + ']'
432434
new_wirevector = wirevector.__class__(name=new_name, bitwidth=1)
433-
block_out.io_map[orig_io_map[wirevector]] = new_wirevector
435+
block_out.io_map[orig_io_map[wirevector]].append(new_wirevector)
434436
else:
435437
new_wirevector = wirevector.__class__(name=new_name, bitwidth=1)
438+
if isinstance(wirevector, Register):
439+
block_out.reg_map[orig_reg_map[wirevector]].append(new_wirevector)
436440
wirevector_map[(wirevector, i)] = new_wirevector
437441

438442
# Now connect up the inputs and outputs to maintain the interface
@@ -441,12 +445,12 @@ def synthesize(update_working_block=True, merge_io_vectors=True, block=None):
441445
input_vector = Input(name=wirevector.name, bitwidth=len(wirevector))
442446
for i in range(len(wirevector)):
443447
wirevector_map[(wirevector, i)] <<= input_vector[i]
444-
block_out.io_map[orig_io_map[wirevector]] = [input_vector]
448+
block_out.io_map[orig_io_map[wirevector]].append(input_vector)
445449
for wirevector in block_in.wirevector_subset(Output):
446450
output_vector = Output(name=wirevector.name, bitwidth=len(wirevector))
447451
output_bits = [wirevector_map[(wirevector, i)] for i in range(len(output_vector))]
448452
output_vector <<= concat_list(output_bits)
449-
block_out.io_map[orig_io_map[wirevector]] = [output_vector]
453+
block_out.io_map[orig_io_map[wirevector]].append(output_vector)
450454

451455
# Now that we have all the wires built and mapped, walk all the blocks
452456
# and map the logic to the equivalent set of primitives in the system

pyrtl/transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def copy_block(block=None, update_working_block=True):
195195
_copy_net(block_out, net, temp_wv_map, mems)
196196
block_out.mem_map = mems
197197
block_out.io_map = {io: w for io, w in temp_wv_map.items() if isinstance(io, (Input, Output))}
198+
block_out.reg_map = {r: w for r, w, in temp_wv_map.items() if isinstance(r, Register)}
198199

199200
if update_working_block:
200201
set_working_block(block_out)

tests/test_passes.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ def test_mux_simulation(self):
6969
self.r.next <<= pyrtl.mux(self.r, 4, 3, 1, 7, 2, 6, 0, 5)
7070
self.check_trace('r 04213756\n')
7171

72+
def test_synthesize_regs_mapped_correctly(self):
73+
r2 = pyrtl.Register(5)
74+
self.r.next <<= ~ self.r
75+
r2.next <<= self.r + 1
76+
synth_block = pyrtl.synthesize()
77+
self.assertEqual(len(synth_block.reg_map), 2)
78+
self.assertEqual(len(synth_block.reg_map[self.r]), len(self.r))
79+
self.assertEqual(len(synth_block.reg_map[r2]), len(r2))
80+
7281

7382
class TestIOInterfaceSynthesis(unittest.TestCase):
7483
def setUp(self):
@@ -104,9 +113,11 @@ def test_synthesize_merged_io_mapped_correctly(self):
104113
pyrtl.synthesize()
105114
new_io = pyrtl.working_block().wirevector_subset((pyrtl.Input, pyrtl.Output))
106115
for oi in old_io:
116+
io_list = pyrtl.working_block().io_map[oi]
117+
self.assertEqual(len(io_list), 1)
107118
for ni in new_io:
108119
if oi.name == ni.name:
109-
self.assertEqual(pyrtl.working_block().io_map[oi], [ni])
120+
self.assertEqual(io_list, [ni])
110121

111122
def test_synthesize_merged_io_simulates_correctly(self):
112123
pyrtl.synthesize()
@@ -130,12 +141,16 @@ def test_synthesize_unmerged_io_names_correct(self):
130141

131142
def test_synthesize_unmerged_io_mapped_correctly(self):
132143
old_io = pyrtl.working_block().wirevector_subset((pyrtl.Input, pyrtl.Output))
133-
pyrtl.synthesize()
144+
pyrtl.synthesize(merge_io_vectors=False)
134145
new_io = pyrtl.working_block().wirevector_subset((pyrtl.Input, pyrtl.Output))
135146
for oi in old_io:
147+
io_list = [w.name for w in pyrtl.working_block().io_map[oi]]
148+
self.assertEqual(len(io_list), len(oi))
136149
for ni in new_io:
137150
if ni.name.startswith(oi.name):
138-
self.assertIn(ni, pyrtl.working_block().io_map[oi])
151+
# Dev note: comparing names because comparing wires (e.g. list/set inclusion)
152+
# creates an '=' net, which is definitely not what we want here.
153+
self.assertIn(ni.name, io_list)
139154

140155
def test_synthesize_unmerged_io_simulates_correctly(self):
141156
pyrtl.synthesize(merge_io_vectors=False)

0 commit comments

Comments
 (0)