Skip to content

Commit e3b1a30

Browse files
authored
Merge pull request #402 from pllab/paths_update
Make paths() function match docs
2 parents e59eaef + b37df27 commit e3b1a30

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

pyrtl/analysis.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,10 @@ def extract_area_delay_from_yosys_output(yosys_output):
410410
def paths(src=None, dst=None, dst_nets=None, block=None):
411411
""" Get the list of paths from src to dst.
412412
413-
:param WireVector src: source wire(s) from which to trace your paths;
414-
if None, will get paths from all Inputs
415-
:param WireVector dst: destination wire(s) to which to trace your paths
416-
if None, will get paths to all Outputs
413+
:param Union[WireVector, Iterable[WireVector]] src: source wire(s) from which to
414+
trace your paths; if None, will get paths from all Inputs
415+
:param Union[WireVector, Iterable[WireVector]] dst: destination wire(s) to which to
416+
trace your paths; if None, will get paths to all Outputs
417417
:param {WireVector: {LogicNet}} dst_nets: map from wire to set of nets where the
418418
wire is an argument; will compute it internally if not given via a
419419
call to pyrtl.net_connections()
@@ -448,10 +448,21 @@ def paths(src=None, dst=None, dst_nets=None, block=None):
448448
for output in block.wirevector_subset(cls=Output):
449449
dst_nets.pop(output, None)
450450

451-
src = block.wirevector_subset(cls=Input) if src is None else {src}
452-
dst = block.wirevector_subset(cls=Output) if dst is None else {dst}
451+
if src is None:
452+
src = block.wirevector_subset(cls=Input)
453+
elif isinstance(src, WireVector):
454+
src = {src}
455+
else:
456+
src = set(src)
457+
458+
if dst is None:
459+
dst = block.wirevector_subset(cls=Output)
460+
elif isinstance(dst, WireVector):
461+
dst = {dst}
462+
else:
463+
dst = set(dst)
453464

454-
def paths_src_dst(src, dst, block=None):
465+
def paths_src_dst(src, dst):
455466
paths = []
456467

457468
# Use DFS to get the paths [each a list of nets] from src wire to dst wire

tests/test_analysis.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,25 @@ def test_two_paths_to_two_outputs(self):
162162
path_to_o2 = paths_from_a[o2]
163163
self.assertEqual(len(path_to_o2), 1)
164164

165+
def test_subset_of_all_paths(self):
166+
i, j, k = pyrtl.input_list('i/2 j/3 k/4')
167+
o, p = pyrtl.Output(), pyrtl.Output()
168+
o <<= i & j
169+
p <<= k - i
170+
171+
# Make sure passing in both set and list works
172+
paths = pyrtl.paths([i, k], {o})
173+
paths_from_i = paths[i]
174+
self.assertNotIn(p, paths_from_i) # Because p was not provided as target output
175+
self.assertEqual(len(paths_from_i[o]), 1) # One path from i to o
176+
self.assertEqual(paths_from_i[o][0][0].op, 'c')
177+
self.assertEqual(paths_from_i[o][0][1].op, '&')
178+
self.assertEqual(paths_from_i[o][0][2].op, 'w')
179+
180+
paths_from_k = paths[k]
181+
self.assertNotIn(p, paths_from_k) # Because p was not provided as target output
182+
self.assertEqual(len(paths_from_k[o]), 0) # 0 paths from k to o
183+
165184
def test_all_paths(self):
166185
a, b, c = pyrtl.input_list('a/2 b/4 c/1')
167186
o, p = pyrtl.output_list('o/4 p/2')

0 commit comments

Comments
 (0)