Skip to content

Commit e370e5a

Browse files
yadomkarOmkar Yadavrossbardschult
authored
Enhancements change default join trees 6947 (networkx#6948)
* change default for new function join_trees networkx#6947 formatting * regression fix * post review changes * Make new kwargs keyword-only. cf. networkxgh-6956. * shift label tests to test_basic and inline 1-liner function. --------- Co-authored-by: Omkar Yadav <[email protected]> Co-authored-by: Ross Barnowski <[email protected]> Co-authored-by: Dan Schult <[email protected]>
1 parent d74179a commit e370e5a

File tree

3 files changed

+53
-41
lines changed

3 files changed

+53
-41
lines changed

networkx/algorithms/tree/coding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _make_tree(sequence):
195195
# For a nonempty sequence, get the subtrees for each child
196196
# sequence and join all the subtrees at their roots. After
197197
# joining the subtrees, the root is node 0.
198-
return nx.tree.join([(_make_tree(child), 0) for child in sequence])
198+
return nx.tree.join_trees([(_make_tree(child), 0) for child in sequence])
199199

200200
# Make the tree and remove the `is_root` node attribute added by the
201201
# helper function.

networkx/algorithms/tree/operations.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ def join(rooted_trees, label_attribute=None):
3131
return join_trees(rooted_trees, label_attribute=label_attribute)
3232

3333

34-
def join_trees(rooted_trees, label_attribute=None):
34+
def join_trees(rooted_trees, *, label_attribute=None, first_label=0):
3535
"""Returns a new rooted tree made by joining `rooted_trees`
3636
3737
Constructs a new tree by joining each tree in `rooted_trees`.
3838
A new root node is added and connected to each of the roots
3939
of the input trees. While copying the nodes from the trees,
40-
relabeling to integers occurs and the old name stored as an
41-
attribute of the new node in the returned graph.
40+
relabeling to integers occurs. If the `label_attribute` is provided,
41+
the old node labels will be stored in the new tree under this attribute.
4242
4343
Parameters
4444
----------
@@ -50,17 +50,20 @@ def join_trees(rooted_trees, label_attribute=None):
5050
5151
label_attribute : str
5252
If provided, the old node labels will be stored in the new tree
53-
under this node attribute. If not provided, the node attribute
54-
``'_old'`` will store the original label of the node in the
55-
rooted trees given in the input.
53+
under this node attribute. If not provided, the original labels
54+
of the nodes in the input trees are not stored.
55+
56+
first_label : int, optional (default=0)
57+
Specifies the label for the new root node. If provided, the root node of the joined tree
58+
will have this label. If not provided, the root node will default to a label of 0.
5659
5760
Returns
5861
-------
5962
NetworkX graph
60-
The rooted tree whose subtrees are the given rooted trees. The
61-
new root node is labeled 0. Each non-root node has an attribute,
62-
as described under the keyword argument ``label_attribute``,
63-
that indicates the label of the original node in the input tree.
63+
The rooted tree resulting from joining the provided `rooted_trees`. The new tree has a root node
64+
labeled as specified by `first_label` (defaulting to 0 if not provided). Subtrees from the input
65+
`rooted_trees` are attached to this new root node. Each non-root node, if the `label_attribute`
66+
is provided, has an attribute that indicates the original label of the node in the input tree.
6467
6568
Notes
6669
-----
@@ -86,42 +89,38 @@ def join_trees(rooted_trees, label_attribute=None):
8689
True
8790
8891
"""
89-
if len(rooted_trees) == 0:
92+
if not rooted_trees:
9093
return nx.empty_graph(1)
9194

9295
# Unzip the zipped list of (tree, root) pairs.
9396
trees, roots = zip(*rooted_trees)
9497

95-
# The join of the trees has the same type as the type of the first
96-
# tree.
98+
# The join of the trees has the same type as the type of the first tree.
9799
R = type(trees[0])()
98100

99-
# Relabel the nodes so that their union is the integers starting at 1.
100-
if label_attribute is None:
101-
label_attribute = "_old"
101+
lengths = (len(tree) for tree in trees[:-1])
102+
first_labels = list(accumulate(lengths, initial=first_label + 1))
103+
104+
new_roots = []
105+
for tree, root, first_node in zip(trees, roots, first_labels):
106+
new_root = first_node + list(tree.nodes()).index(root)
107+
new_roots.append(new_root)
108+
109+
# Relabel the nodes so that their union is the integers starting at first_label.
102110
relabel = partial(
103111
nx.convert_node_labels_to_integers, label_attribute=label_attribute
104112
)
105-
lengths = (len(tree) for tree in trees[:-1])
106-
first_labels = chain([0], accumulate(lengths))
107-
trees = [
108-
relabel(tree, first_label=first_label + 1)
113+
new_trees = [
114+
relabel(tree, first_label=first_label)
109115
for tree, first_label in zip(trees, first_labels)
110116
]
111117

112-
# Get the relabeled roots.
113-
roots = [
114-
next(v for v, d in tree.nodes(data=True) if d.get(label_attribute) == root)
115-
for tree, root in zip(trees, roots)
116-
]
117-
118118
# Add all sets of nodes and edges, attributes
119-
for tree in trees:
119+
for tree in new_trees:
120120
R.update(tree)
121121

122-
# Finally, join the subtrees at the root. We know 0 is unused by the
123-
# way we relabeled the subtrees.
124-
R.add_node(0)
125-
R.add_edges_from((0, root) for root in roots)
122+
# Finally, join the subtrees at the root. We know first_label is unused by the way we relabeled the subtrees.
123+
R.add_node(first_label)
124+
R.add_edges_from((first_label, root) for root in new_roots)
126125

127126
return R

networkx/algorithms/tree/tests/test_operations.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from networkx.utils import edges_equal, nodes_equal
55

66

7-
def _check_label_attribute(input_trees, res_tree, label_attribute="_old"):
7+
def _check_custom_label_attribute(input_trees, res_tree, label_attribute):
88
res_attr_dict = nx.get_node_attributes(res_tree, label_attribute)
99
res_attr_set = set(res_attr_dict.values())
10-
input_label = (list(tree[0].nodes()) for tree in input_trees)
10+
input_label = (tree for tree, root in input_trees)
1111
input_label_set = set(chain.from_iterable(input_label))
1212
return res_attr_set == input_label_set
1313

@@ -23,18 +23,31 @@ def test_single():
2323
"""Joining just one tree yields a tree with one more node."""
2424
T = nx.empty_graph(1)
2525
trees = [(T, 0)]
26-
actual = nx.join_trees(trees)
26+
actual_with_label = nx.join_trees(trees, label_attribute="custom_label")
2727
expected = nx.path_graph(2)
28-
assert nodes_equal(list(expected), list(actual))
29-
assert edges_equal(list(expected.edges()), list(actual.edges()))
30-
assert _check_label_attribute(trees, actual)
28+
assert nodes_equal(list(expected), list(actual_with_label))
29+
assert edges_equal(list(expected.edges()), list(actual_with_label.edges()))
3130

3231

3332
def test_basic():
3433
"""Joining multiple subtrees at a root node."""
3534
trees = [(nx.full_rary_tree(2, 2**2 - 1), 0) for i in range(2)]
36-
label_attribute = "old_values"
37-
actual = nx.join_trees(trees, label_attribute)
3835
expected = nx.full_rary_tree(2, 2**3 - 1)
36+
actual = nx.join_trees(trees, label_attribute="old_labels")
3937
assert nx.is_isomorphic(actual, expected)
40-
assert _check_label_attribute(trees, actual, label_attribute)
38+
assert _check_custom_label_attribute(trees, actual, "old_labels")
39+
40+
actual_without_label = nx.join_trees(trees)
41+
assert nx.is_isomorphic(actual_without_label, expected)
42+
# check that no labels were stored
43+
assert all(not data for _, data in actual_without_label.nodes(data=True))
44+
45+
46+
def test_first_label():
47+
"""Test the functionality of the first_label argument."""
48+
T1 = nx.path_graph(3)
49+
T2 = nx.path_graph(2)
50+
actual = nx.join_trees([(T1, 0), (T2, 0)], first_label=10)
51+
expected_nodes = set(range(10, 16))
52+
assert set(actual.nodes()) == expected_nodes
53+
assert set(actual.neighbors(10)) == {11, 14}

0 commit comments

Comments
 (0)