Skip to content

Commit 9245a02

Browse files
committed
Improve test_edge_registry
- Add test_all_edges_are_leaves2 - Simplify tests
1 parent d6b76fb commit 9245a02

File tree

1 file changed

+41
-49
lines changed

1 file changed

+41
-49
lines changed

tests/unit/autogram/test_edge_registry.py

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,51 @@
44
from torchjd.autogram._edge_registry import EdgeRegistry
55

66

7-
def test_all_edges_are_leaves():
7+
def test_all_edges_are_leaves1():
8+
"""Tests that get_leaf_edges works correctly when all edges are already leaves."""
9+
810
a = randn_([3, 4], requires_grad=True)
911
b = randn_([4], requires_grad=True)
1012
c = randn_([3], requires_grad=True)
1113

1214
d = (a @ b) + c
1315

1416
edge_registry = EdgeRegistry()
17+
for tensor in [a, b, c]:
18+
edge_registry.register(get_gradient_edge(tensor))
1519

16-
registered_edges = {
17-
a: get_gradient_edge(a),
18-
b: get_gradient_edge(b),
19-
c: get_gradient_edge(c),
20-
}
20+
expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c]}
21+
leaves = edge_registry.get_leaf_edges({get_gradient_edge(d)}, set())
22+
assert leaves == expected_leaves
2123

22-
expected_leaves = {registered_edges[a], registered_edges[b], registered_edges[c]}
2324

24-
for edge in registered_edges.values():
25-
edge_registry.register(edge)
25+
def test_all_edges_are_leaves2():
26+
"""
27+
Tests that get_leaf_edges works correctly when all edges are already leaves of the graph of
28+
edges leading to them, but are not leaves of the autograd graph.
29+
"""
2630

27-
leaves = edge_registry.get_leaf_edges({get_gradient_edge(d)}, set())
31+
a = randn_([3, 4], requires_grad=True)
32+
b = randn_([4], requires_grad=True)
33+
c = randn_([4], requires_grad=True)
34+
d = randn_([4], requires_grad=True)
35+
36+
e = a * b
37+
f = e + c
38+
g = f + d
39+
40+
edge_registry = EdgeRegistry()
41+
for tensor in [e, g]:
42+
edge_registry.register(get_gradient_edge(tensor))
2843

44+
expected_leaves = {get_gradient_edge(tensor) for tensor in [e, g]}
45+
leaves = edge_registry.get_leaf_edges({get_gradient_edge(e), get_gradient_edge(g)}, set())
2946
assert leaves == expected_leaves
3047

3148

3249
def test_some_edges_are_not_leaves1():
50+
"""Tests that get_leaf_edges works correctly when some edges are leaves and some are not."""
51+
3352
a = randn_([3, 4], requires_grad=True)
3453
b = randn_([4], requires_grad=True)
3554
c = randn_([4], requires_grad=True)
@@ -40,33 +59,21 @@ def test_some_edges_are_not_leaves1():
4059
g = f + d
4160

4261
edge_registry = EdgeRegistry()
62+
for tensor in [a, b, c, d, e, f, g]:
63+
edge_registry.register(get_gradient_edge(tensor))
4364

44-
registered_edges = {
45-
a: get_gradient_edge(a),
46-
b: get_gradient_edge(b),
47-
c: get_gradient_edge(c),
48-
d: get_gradient_edge(d),
49-
e: get_gradient_edge(e),
50-
f: get_gradient_edge(f),
51-
g: get_gradient_edge(g),
52-
}
53-
54-
expected_leaves = {
55-
registered_edges[a],
56-
registered_edges[b],
57-
registered_edges[c],
58-
registered_edges[d],
59-
}
60-
61-
for edge in registered_edges.values():
62-
edge_registry.register(edge)
63-
65+
expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c, d]}
6466
leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)}, set())
65-
6667
assert leaves == expected_leaves
6768

6869

6970
def test_some_edges_are_not_leaves2():
71+
"""
72+
Tests that get_leaf_edges works correctly when some edges are leaves and some are not. This
73+
time, not all tensors in the graph are registered so not all leavese in the graph have to be
74+
returned.
75+
"""
76+
7077
a = randn_([3, 4], requires_grad=True)
7178
b = randn_([4], requires_grad=True)
7279
c = randn_([4], requires_grad=True)
@@ -77,24 +84,9 @@ def test_some_edges_are_not_leaves2():
7784
g = f + d
7885

7986
edge_registry = EdgeRegistry()
87+
for tensor in [a, c, d, e, g]:
88+
edge_registry.register(get_gradient_edge(tensor))
8089

81-
registered_edges = {
82-
a: get_gradient_edge(a),
83-
c: get_gradient_edge(c),
84-
d: get_gradient_edge(d),
85-
e: get_gradient_edge(e),
86-
g: get_gradient_edge(g),
87-
}
88-
89-
expected_leaves = {
90-
registered_edges[a],
91-
registered_edges[c],
92-
registered_edges[d],
93-
}
94-
95-
for edge in registered_edges.values():
96-
edge_registry.register(edge)
97-
90+
expected_leaves = {get_gradient_edge(tensor) for tensor in [a, c, d]}
9891
leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)}, set())
99-
10092
assert leaves == expected_leaves

0 commit comments

Comments
 (0)