Skip to content

Commit f677df5

Browse files
authored
docs: Fix monitoring example (#358)
* Rename print_similarity_with_gd to print_gd_similarity * Change type hint of print_gd_similarity inputs to tuple[torch.Tensor, ...] * Reorder imports in rst to match the test
1 parent 340479e commit f677df5

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

docs/source/examples/monitoring.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ they have a negative inner product).
1919
2020
import torch
2121
from torch.nn import Linear, MSELoss, ReLU, Sequential
22-
from torch.optim import SGD
2322
from torch.nn.functional import cosine_similarity
23+
from torch.optim import SGD
2424
2525
from torchjd import mtl_backward
2626
from torchjd.aggregation import UPGrad
@@ -29,7 +29,7 @@ they have a negative inner product).
2929
"""Prints the extracted weights."""
3030
print(f"Weights: {weights}")
3131
32-
def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None:
32+
def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
3333
"""Prints the cosine similarity between the aggregation and the average gradient."""
3434
matrix = inputs[0]
3535
gd_output = matrix.mean(dim=0)
@@ -50,7 +50,7 @@ they have a negative inner product).
5050
aggregator = UPGrad()
5151
5252
aggregator.weighting.register_forward_hook(print_weights)
53-
aggregator.register_forward_hook(print_similarity_with_gd)
53+
aggregator.register_forward_hook(print_gd_similarity)
5454
5555
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
5656
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task

tests/doc/test_rst.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def print_weights(_, __, weights: torch.Tensor) -> None:
222222
"""Prints the extracted weights."""
223223
print(f"Weights: {weights}")
224224

225-
def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None:
225+
def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
226226
"""Prints the cosine similarity between the aggregation and the average gradient."""
227227
matrix = inputs[0]
228228
gd_output = matrix.mean(dim=0)
@@ -243,7 +243,7 @@ def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.
243243
aggregator = UPGrad()
244244

245245
aggregator.weighting.register_forward_hook(print_weights)
246-
aggregator.register_forward_hook(print_similarity_with_gd)
246+
aggregator.register_forward_hook(print_gd_similarity)
247247

248248
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
249249
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task

0 commit comments

Comments
 (0)