Skip to content

Commit 6b0e0b0

Browse files
committed
Improve paths() function, more tests, small update on trailing whitespace when printing a net
1 parent b37df27 commit 6b0e0b0

File tree

3 files changed

+189
-8
lines changed

3 files changed

+189
-8
lines changed

pyrtl/analysis.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,34 @@ def extract_area_delay_from_yosys_output(yosys_output):
407407
return area, delay
408408

409409

410+
class PathsResult(dict):
411+
def print(self, file=sys.stdout):
412+
""" Pretty print the result of calling paths()
413+
414+
:param f: the open file to print to (defaults to stdout)
415+
:return: None
416+
"""
417+
# All this work, to make sure it's determinstic
418+
def path_sort_key(path):
419+
dst_names = [net.dests[0].name if net.dests else '' for net in path]
420+
return (len(path), dst_names)
421+
422+
for start in sorted(self.keys(), key=lambda w: w.name):
423+
print("From %s" % start.name, file=file)
424+
for end in sorted(self[start].keys(), key=lambda w: w.name):
425+
print(" To %s" % end.name, file=file)
426+
paths = self[start][end]
427+
if len(paths) > 0:
428+
for i, paths in enumerate(sorted(paths, key=path_sort_key)):
429+
print(" Path %d" % i, file=file)
430+
for path in paths:
431+
print(" %s" % str(path), file=file)
432+
else:
433+
print(" (No paths)", file=file)
434+
435+
410436
def paths(src=None, dst=None, dst_nets=None, block=None):
411-
""" Get the list of paths from src to dst.
437+
""" Get the list of all paths from src to dst.
412438
413439
:param Union[WireVector, Iterable[WireVector]] src: source wire(s) from which to
414440
trace your paths; if None, will get paths from all Inputs
@@ -420,7 +446,8 @@ def paths(src=None, dst=None, dst_nets=None, block=None):
420446
:param Block block: block to use (defaults to working block)
421447
:return: a map of the form {src_wire: {dst_wire: [path]}} for each src_wire in src
422448
(or all inputs if src is None), dst_wire in dst (or all outputs if dst is None),
423-
where path is a list of nets
449+
where path is a list of nets. This map is also an instance of PathsResult,
450+
so you can call `print()` on it to pretty print it.
424451
425452
You can provide dst_nets (the result of calling pyrtl.net_connections()), if you plan
426453
on calling this function repeatedly on a block that hasn't changed, to speed things up.
@@ -434,6 +461,11 @@ def paths(src=None, dst=None, dst_nets=None, block=None):
434461
to a given dst wire.
435462
436463
If src and dst are both single wires, you still need to access the result via paths[src][dst].
464+
465+
This also finds and returns the loop paths in the case of registers or memories that feed into
466+
themselves, i.e. paths[src][src] is not necessarily empty.
467+
468+
It does not distinguish between loops that include synchronous vs asynchronous memories.
437469
"""
438470
block = working_block(block)
439471

@@ -472,17 +504,36 @@ def dfs(w, curr_path):
472504
paths.append(curr_path)
473505
for dst_net in dst_nets.get(w, []):
474506
# Avoid loops and the mem net (has no output wire)
475-
if (dst_net not in curr_path) and (dst_net.op != '@'):
476-
dfs(dst_net.dests[0], curr_path + [dst_net])
507+
if dst_net not in curr_path:
508+
if dst_net.op == '@': # dests will be the read ports
509+
for read_net in dst_net.op_param[1].readport_nets:
510+
dfs(read_net.dests[0], curr_path + [dst_net, read_net])
511+
else:
512+
dfs(dst_net.dests[0], curr_path + [dst_net])
477513
dfs(src, [])
478514
return paths
479515

480516
all_paths = collections.defaultdict(dict)
481517
for src_wire in src:
482518
for dst_wire in dst:
483-
all_paths[src_wire][dst_wire] = paths_src_dst(src_wire, dst_wire)
484-
485-
return all_paths
519+
paths = paths_src_dst(src_wire, dst_wire)
520+
# Remove empty paths...
521+
paths = list(filter(lambda x: len(x) > 0, paths))
522+
# ...and those that are supersets of others (resulting from an inner loop).
523+
if src_wire is not dst_wire:
524+
paths = sorted(paths, key=lambda p: len(p), reverse=True)
525+
keep = []
526+
for i in range(len(paths)):
527+
# Check if there is a path in paths[i+1:] that is the suffix
528+
# of paths[i] (paths[i] is at least as large as each path in
529+
# paths[i+1:]). If so, paths[i] contains a loop since both start
530+
# at src_wire, so don't keep it.
531+
if not any(paths[i][-len(p):] == p for p in paths[i + 1:]):
532+
keep.append(paths[i])
533+
paths = keep
534+
all_paths[src_wire][dst_wire] = paths
535+
536+
return PathsResult(all_paths)
486537

487538

488539
def distance(src, dst, f, block=None):

pyrtl/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def __str__(self):
123123

124124
else: # not in ipython
125125
if self.op in 'w~&|^n+-*<>=xcsr':
126-
return "{} <-- {} -- {} {}".format(lhs, self.op, rhs, options)
126+
options = ' ' + options if options else ''
127+
return "{} <-- {} -- {}{}".format(lhs, self.op, rhs, options)
127128
elif self.op in 'm@':
128129
memid, memblock = self.op_param
129130
extrainfo = 'memid=' + str(memid)

tests/test_analysis.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import print_function, unicode_literals, absolute_import
22

33
import unittest
4+
import io
45
import pyrtl
56

67

@@ -113,10 +114,50 @@ def setUp(self):
113114
pyrtl.reset_working_block()
114115

115116

117+
paths_print_output = """\
118+
From i
119+
To o
120+
Path 0
121+
tmp5/3W <-- - -- i/2I, tmp4/2W
122+
tmp6/3W <-- | -- tmp2/3W, tmp5/3W
123+
o/3O <-- w -- tmp6/3W
124+
Path 1
125+
tmp1/3W <-- c -- tmp0/1W, i/2I
126+
tmp2/3W <-- & -- tmp1/3W, j/3I
127+
tmp6/3W <-- | -- tmp2/3W, tmp5/3W
128+
o/3O <-- w -- tmp6/3W
129+
To p
130+
Path 0
131+
tmp8/4W <-- c -- tmp7/2W, i/2I
132+
tmp9/5W <-- - -- k/4I, tmp8/4W
133+
p/5O <-- w -- tmp9/5W
134+
From j
135+
To o
136+
Path 0
137+
tmp2/3W <-- & -- tmp1/3W, j/3I
138+
tmp6/3W <-- | -- tmp2/3W, tmp5/3W
139+
o/3O <-- w -- tmp6/3W
140+
To p
141+
(No paths)
142+
From k
143+
To o
144+
(No paths)
145+
To p
146+
Path 0
147+
tmp9/5W <-- - -- k/4I, tmp8/4W
148+
p/5O <-- w -- tmp9/5W
149+
"""
150+
151+
116152
class TestPaths(unittest.TestCase):
117153

118154
def setUp(self):
119155
pyrtl.reset_working_block()
156+
# To compare textual consistency, need to make
157+
# sure we're starting at the same index for all
158+
# automatically created names.
159+
pyrtl.wire._reset_wire_indexers()
160+
pyrtl.memory._reset_memory_indexer()
120161

121162
def test_one_path_to_one_output(self):
122163
a = pyrtl.Input(4, 'a')
@@ -181,6 +222,83 @@ def test_subset_of_all_paths(self):
181222
self.assertNotIn(p, paths_from_k) # Because p was not provided as target output
182223
self.assertEqual(len(paths_from_k[o]), 0) # 0 paths from k to o
183224

225+
def test_paths_empty_src_and_dst_equal_with_no_other_logic(self):
226+
i = pyrtl.Input(4, 'i')
227+
paths = pyrtl.paths(i, i)
228+
self.assertEqual(len(paths[i][i]), 0)
229+
230+
def test_paths_with_loop(self):
231+
r = pyrtl.Register(1, 'r')
232+
r.next <<= r & ~r
233+
paths = pyrtl.paths(r, r)
234+
self.assertEqual(len(paths[r][r]), 2)
235+
p1, p2 = sorted(paths[r][r], key=lambda p: len(p), reverse=True)
236+
self.assertEqual(len(p1), 3)
237+
self.assertEqual(p1[0].op, '~')
238+
self.assertEqual(p1[1].op, '&')
239+
self.assertEqual(p1[2].op, 'r')
240+
self.assertEqual(len(p2), 2)
241+
self.assertEqual(p2[0].op, '&')
242+
self.assertEqual(p2[1].op, 'r')
243+
244+
def test_paths_loop_and_input(self):
245+
i = pyrtl.Input(1, 'i')
246+
o = pyrtl.Output(1, 'o')
247+
r = pyrtl.Register(1, 'r')
248+
r.next <<= i & r
249+
o <<= r
250+
paths = pyrtl.paths(r, o)
251+
self.assertEqual(len(paths[r][o]), 1)
252+
253+
def test_paths_loop_get_arbitrary_inner_wires(self):
254+
w = pyrtl.WireVector(1, 'w')
255+
y = w & pyrtl.Const(1)
256+
w <<= ~y
257+
paths = pyrtl.paths(w, y)
258+
self.assertEqual(len(paths[w][y]), 1)
259+
self.assertEqual(paths[w][y][0][0].op, '&')
260+
261+
def test_paths_no_path_exists(self):
262+
i = pyrtl.Input(1, 'i')
263+
o = pyrtl.Output(1, 'o')
264+
o <<= ~i
265+
266+
w = pyrtl.WireVector(1, 'w')
267+
y = w & pyrtl.Const(1)
268+
w <<= ~y
269+
270+
paths = pyrtl.paths(w, o)
271+
self.assertEqual(len(paths[w][o]), 0)
272+
273+
def test_paths_with_memory(self):
274+
i = pyrtl.Input(4, 'i')
275+
o = pyrtl.Output(8, 'o')
276+
mem = pyrtl.MemBlock(8, 32, 'mem')
277+
waddr = pyrtl.Input(32, 'waddr')
278+
raddr = pyrtl.Input(32, 'raddr')
279+
data = mem[raddr]
280+
mem[waddr] <<= (i + ~data).truncate(8)
281+
o <<= data
282+
283+
paths = pyrtl.paths(i, o)
284+
path = paths[i][o][0]
285+
self.assertEqual(path[0].op, 'c')
286+
self.assertEqual(path[1].op, '+')
287+
self.assertEqual(path[2].op, 's')
288+
self.assertEqual(path[3].op, '@')
289+
self.assertEqual(path[4].op, 'm')
290+
self.assertEqual(path[5].op, 'w')
291+
292+
# TODO Once issue with _MemIndexed lookups is resolved,
293+
# these should be `data` instead of `data.wire`.
294+
paths = pyrtl.paths(data.wire, data.wire)
295+
path = paths[data.wire][data.wire][0]
296+
self.assertEqual(path[0].op, '~')
297+
self.assertEqual(path[1].op, '+')
298+
self.assertEqual(path[2].op, 's')
299+
self.assertEqual(path[3].op, '@')
300+
self.assertEqual(path[4].op, 'm')
301+
184302
def test_all_paths(self):
185303
a, b, c = pyrtl.input_list('a/2 b/4 c/1')
186304
o, p = pyrtl.output_list('o/4 p/2')
@@ -218,6 +336,17 @@ def test_all_paths(self):
218336
paths_c_to_p = paths[c][p]
219337
self.assertEqual(len(paths_c_to_p), 1)
220338

339+
def test_pretty_print(self):
340+
i, j, k = pyrtl.input_list('i/2 j/3 k/4')
341+
o, p = pyrtl.Output(name='o'), pyrtl.Output(name='p')
342+
o <<= (i & j) | (i - 1)
343+
p <<= k - i
344+
345+
paths = pyrtl.paths()
346+
output = io.StringIO()
347+
paths.print(file=output)
348+
self.assertEqual(output.getvalue(), paths_print_output)
349+
221350

222351
class TestDistance(unittest.TestCase):
223352

0 commit comments

Comments
 (0)