44from torchjd .autogram ._edge_registry import EdgeRegistry
55
66
7- def test_all_edges_are_leaves ():
7+ def test_all_edges_are_leaves1 ():
8+ """Tests that get_leaf_edges works correctly when all edges are already leaves."""
9+
810 a = randn_ ([3 , 4 ], requires_grad = True )
911 b = randn_ ([4 ], requires_grad = True )
1012 c = randn_ ([3 ], requires_grad = True )
1113
1214 d = (a @ b ) + c
1315
1416 edge_registry = EdgeRegistry ()
17+ for tensor in [a , b , c ]:
18+ edge_registry .register (get_gradient_edge (tensor ))
1519
16- registered_edges = {
17- a : get_gradient_edge (a ),
18- b : get_gradient_edge (b ),
19- c : get_gradient_edge (c ),
20- }
20+ expected_leaves = {get_gradient_edge (tensor ) for tensor in [a , b , c ]}
21+ leaves = edge_registry .get_leaf_edges ({get_gradient_edge (d )}, set ())
22+ assert leaves == expected_leaves
2123
22- expected_leaves = {registered_edges [a ], registered_edges [b ], registered_edges [c ]}
2324
24- for edge in registered_edges .values ():
25- edge_registry .register (edge )
25+ def test_all_edges_are_leaves2 ():
26+ """
27+ Tests that get_leaf_edges works correctly when all edges are already leaves of the graph of
28+ edges leading to them, but are not leaves of the autograd graph.
29+ """
2630
27- leaves = edge_registry .get_leaf_edges ({get_gradient_edge (d )}, set ())
31+ a = randn_ ([3 , 4 ], requires_grad = True )
32+ b = randn_ ([4 ], requires_grad = True )
33+ c = randn_ ([4 ], requires_grad = True )
34+ d = randn_ ([4 ], requires_grad = True )
35+
36+ e = a * b
37+ f = e + c
38+ g = f + d
39+
40+ edge_registry = EdgeRegistry ()
41+ for tensor in [e , g ]:
42+ edge_registry .register (get_gradient_edge (tensor ))
2843
44+ expected_leaves = {get_gradient_edge (tensor ) for tensor in [e , g ]}
45+ leaves = edge_registry .get_leaf_edges ({get_gradient_edge (e ), get_gradient_edge (g )}, set ())
2946 assert leaves == expected_leaves
3047
3148
3249def test_some_edges_are_not_leaves1 ():
50+ """Tests that get_leaf_edges works correctly when some edges are leaves and some are not."""
51+
3352 a = randn_ ([3 , 4 ], requires_grad = True )
3453 b = randn_ ([4 ], requires_grad = True )
3554 c = randn_ ([4 ], requires_grad = True )
@@ -40,33 +59,21 @@ def test_some_edges_are_not_leaves1():
4059 g = f + d
4160
4261 edge_registry = EdgeRegistry ()
62+ for tensor in [a , b , c , d , e , f , g ]:
63+ edge_registry .register (get_gradient_edge (tensor ))
4364
44- registered_edges = {
45- a : get_gradient_edge (a ),
46- b : get_gradient_edge (b ),
47- c : get_gradient_edge (c ),
48- d : get_gradient_edge (d ),
49- e : get_gradient_edge (e ),
50- f : get_gradient_edge (f ),
51- g : get_gradient_edge (g ),
52- }
53-
54- expected_leaves = {
55- registered_edges [a ],
56- registered_edges [b ],
57- registered_edges [c ],
58- registered_edges [d ],
59- }
60-
61- for edge in registered_edges .values ():
62- edge_registry .register (edge )
63-
65+ expected_leaves = {get_gradient_edge (tensor ) for tensor in [a , b , c , d ]}
6466 leaves = edge_registry .get_leaf_edges ({get_gradient_edge (g )}, set ())
65-
6667 assert leaves == expected_leaves
6768
6869
6970def test_some_edges_are_not_leaves2 ():
71+ """
72+ Tests that get_leaf_edges works correctly when some edges are leaves and some are not. This
73+ time, not all tensors in the graph are registered so not all leavese in the graph have to be
74+ returned.
75+ """
76+
7077 a = randn_ ([3 , 4 ], requires_grad = True )
7178 b = randn_ ([4 ], requires_grad = True )
7279 c = randn_ ([4 ], requires_grad = True )
@@ -77,24 +84,9 @@ def test_some_edges_are_not_leaves2():
7784 g = f + d
7885
7986 edge_registry = EdgeRegistry ()
87+ for tensor in [a , c , d , e , g ]:
88+ edge_registry .register (get_gradient_edge (tensor ))
8089
81- registered_edges = {
82- a : get_gradient_edge (a ),
83- c : get_gradient_edge (c ),
84- d : get_gradient_edge (d ),
85- e : get_gradient_edge (e ),
86- g : get_gradient_edge (g ),
87- }
88-
89- expected_leaves = {
90- registered_edges [a ],
91- registered_edges [c ],
92- registered_edges [d ],
93- }
94-
95- for edge in registered_edges .values ():
96- edge_registry .register (edge )
97-
90+ expected_leaves = {get_gradient_edge (tensor ) for tensor in [a , c , d ]}
9891 leaves = edge_registry .get_leaf_edges ({get_gradient_edge (g )}, set ())
99-
10092 assert leaves == expected_leaves
0 commit comments