Commit c22e256
Remove parameter excluded from get_leaf_edges

* It was never used.
* Since roots was previously not actually modified in-place, this change also sets excluded = roots, which makes roots actually be modified in place (as the documentation says). The function therefore behaves differently than before, but in the expected way.
1 parent 9245a02 commit c22e256
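The behavioral note in the commit message comes down to Python set aliasing. Here is a minimal standalone sketch (not code from the repo) contrasting the old `excluded.update(roots)` on a separate set with the new `excluded = roots` binding:

```python
# New behavior: `excluded = roots` binds a second name to the SAME set object,
# so any later mutation through `excluded` is visible through `roots`.
roots = {"edge_a", "edge_b"}
excluded = roots
excluded.add("edge_c")
assert "edge_c" in roots  # roots is modified in place, as documented

# Old behavior: updating a fresh caller-supplied set copied the elements,
# so `roots` itself was never mutated by this step.
roots = {"edge_a", "edge_b"}
excluded = set()
excluded.update(roots)
excluded.add("edge_c")
assert "edge_c" not in roots  # roots untouched
```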

File tree

3 files changed: +7 −11 lines changed

src/torchjd/autogram/_edge_registry.py

Lines changed: 2 additions & 6 deletions

@@ -23,25 +23,21 @@ def register(self, edge: GradientEdge) -> None:
         """
         self._edges.add(edge)

-    def get_leaf_edges(
-        self, roots: set[GradientEdge], excluded: set[GradientEdge]
-    ) -> set[GradientEdge]:
+    def get_leaf_edges(self, roots: set[GradientEdge]) -> set[GradientEdge]:
         """
         Compute a minimal subset of edges that yields the same differentiation graph traversal: the
         leaf edges. Specifically, this removes edges that are reachable from other edges in the
         differentiation graph, avoiding the need to keep gradients in memory for all edges
         simultaneously.

         :param roots: Roots of the graph traversal. Modified in-place.
-        :param excluded: GradientEdges that stop graph traversal. Modified in-place.
         :returns: Minimal subset of leaf edges.
         """

-        roots.difference_update(excluded)
         nodes_to_traverse = deque((child, root) for root in roots for child in _next_edges(root))
         result = {root for root in roots if root in self._edges}

-        excluded.update(roots)
+        excluded = roots
         while nodes_to_traverse:
             node, origin = nodes_to_traverse.popleft()
             if node in self._edges:
src/torchjd/autogram/_engine.py

Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ def compute_gramian(self, output: Tensor) -> Tensor:
     def _compute_square_gramian(self, output: Tensor) -> Tensor:
         self._module_hook_manager.gramian_accumulation_phase = True

-        leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}, set()))
+        leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))

         def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
             return torch.autograd.grad(

tests/unit/autogram/test_edge_registry.py

Lines changed: 4 additions & 4 deletions

@@ -18,7 +18,7 @@ def test_all_edges_are_leaves1():
         edge_registry.register(get_gradient_edge(tensor))

     expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c]}
-    leaves = edge_registry.get_leaf_edges({get_gradient_edge(d)}, set())
+    leaves = edge_registry.get_leaf_edges({get_gradient_edge(d)})
     assert leaves == expected_leaves


@@ -42,7 +42,7 @@ def test_all_edges_are_leaves2():
         edge_registry.register(get_gradient_edge(tensor))

     expected_leaves = {get_gradient_edge(tensor) for tensor in [e, g]}
-    leaves = edge_registry.get_leaf_edges({get_gradient_edge(e), get_gradient_edge(g)}, set())
+    leaves = edge_registry.get_leaf_edges({get_gradient_edge(e), get_gradient_edge(g)})
     assert leaves == expected_leaves


@@ -63,7 +63,7 @@ def test_some_edges_are_not_leaves1():
         edge_registry.register(get_gradient_edge(tensor))

     expected_leaves = {get_gradient_edge(tensor) for tensor in [a, b, c, d]}
-    leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)}, set())
+    leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)})
     assert leaves == expected_leaves


@@ -88,5 +88,5 @@ def test_some_edges_are_not_leaves2():
         edge_registry.register(get_gradient_edge(tensor))

     expected_leaves = {get_gradient_edge(tensor) for tensor in [a, c, d]}
-    leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)}, set())
+    leaves = edge_registry.get_leaf_edges({get_gradient_edge(g)})
     assert leaves == expected_leaves
