Commit 6d799a3

Fix(aggregation): Fix some aggregator __repr__ (#362)
* Fix NashMTL.__repr__ to also include max_norm, update_weights_every and optim_niter. Adapt its representation test.
* Fix GradDrop.__repr__ to also include f. Adapt its representation test using a regex.
* Add _epsilon to MGDA.
* Add _weights to Constant.
1 parent 9b3cc37 commit 6d799a3

6 files changed: 23 additions (+), 12 deletions (-)
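
The four source changes below apply one pattern: keep a private copy of each constructor argument on the aggregator itself, so that __repr__ can report the full constructor call without reaching into nested weighting objects. A minimal self-contained sketch of that pattern (_Weighting and Aggregator are hypothetical stand-ins, not torchjd classes):

# Minimal sketch of the pattern behind this commit; these are
# hypothetical stand-in classes, not torchjd code.
class _Weighting:
    def __init__(self, epsilon: float):
        self.epsilon = epsilon

class Aggregator:
    def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
        self.weighting = _Weighting(epsilon=epsilon)
        # Stored on the aggregator so __repr__ never has to reach
        # through self.weighting.
        self._epsilon = epsilon
        self._max_iters = max_iters

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})"

print(repr(Aggregator()))  # Aggregator(epsilon=0.001, max_iters=100)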

src/torchjd/aggregation/_constant.py (3 additions, 2 deletions)

@@ -29,12 +29,13 @@ class Constant(WeightedAggregator):

     def __init__(self, weights: Tensor):
         super().__init__(weighting=_ConstantWeighting(weights=weights))
+        self._weights = weights

     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(weights={repr(self.weighting.weights)})"
+        return f"{self.__class__.__name__}(weights={repr(self._weights)})"

     def __str__(self) -> str:
-        weights_str = vector_to_str(self.weighting.weights)
+        weights_str = vector_to_str(self._weights)
         return f"{self.__class__.__name__}([{weights_str}])"
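
A hedged usage sketch of the new behavior, assuming Constant is exported from torchjd.aggregation as the package layout suggests (outputs shown are indicative):

import torch
from torchjd.aggregation import Constant  # assumed public export

a = Constant(weights=torch.tensor([1.0, 2.0]))
print(repr(a))  # e.g. Constant(weights=tensor([1., 2.]))
print(str(a))   # e.g. Constant([1., 2.])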

src/torchjd/aggregation/_graddrop.py (1 addition, 1 deletion)

@@ -82,7 +82,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
         )

     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(leak={repr(self.leak)})"
+        return f"{self.__class__.__name__}(f={repr(self.f)}, leak={repr(self.leak)})"

     def __str__(self) -> str:
         if self.leak is None:
src/torchjd/aggregation/_mgda.py (2 additions, 4 deletions)

@@ -33,13 +33,11 @@ class MGDA(GramianWeightedAggregator):

     def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
         super().__init__(_MGDAWeighting(epsilon=epsilon, max_iters=max_iters))
+        self._epsilon = epsilon
         self._max_iters = max_iters

     def __repr__(self) -> str:
-        return (
-            f"{self.__class__.__name__}(epsilon={self.weighting.weighting.epsilon}, max_iters="
-            f"{self._max_iters})"
-        )
+        return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})"


 class _MGDAWeighting(Weighting[PSDMatrix]):
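
The old __repr__ reached two attribute levels deep (self.weighting.weighting.epsilon); storing _epsilon flattens that. A hedged side effect, assuming MGDA is exported from torchjd.aggregation: with every constructor argument in the repr, the representation round-trips through eval (illustrative check, not a torchjd test):

from torchjd.aggregation import MGDA  # assumed public export

a = MGDA(epsilon=0.01, max_iters=50)
assert repr(a) == "MGDA(epsilon=0.01, max_iters=50)"
assert repr(eval(repr(a))) == repr(a)  # repr round-trips through eval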

src/torchjd/aggregation/_nash_mtl.py (8 additions, 1 deletion)

@@ -96,6 +96,10 @@ def __init__(
                 optim_niter=optim_niter,
             )
         )
+        self._n_tasks = n_tasks
+        self._max_norm = max_norm
+        self._update_weights_every = update_weights_every
+        self._optim_niter = optim_niter

         # This prevents considering the computed weights as constant w.r.t. the matrix.
         self.register_full_backward_pre_hook(raise_non_differentiable_error)
@@ -105,7 +109,10 @@ def reset(self) -> None:
         self.weighting.reset()

     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(n_tasks={self.weighting.n_tasks})"
+        return (
+            f"{self.__class__.__name__}(n_tasks={self._n_tasks}, max_norm={self._max_norm}, "
+            f"update_weights_every={self._update_weights_every}, optim_niter={self._optim_niter})"
+        )


 class _NashMTLWeighting(Weighting[Matrix]):
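
A hedged usage sketch, assuming NashMTL is exported from torchjd.aggregation; the default values shown in the comment are illustrative, not verified against the library:

from torchjd.aggregation import NashMTL  # assumed public export

a = NashMTL(n_tasks=3)
print(repr(a))  # e.g. NashMTL(n_tasks=3, max_norm=1.0, update_weights_every=1, optim_niter=20)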

tests/unit/aggregation/test_graddrop.py (7 additions, 2 deletions)

@@ -1,3 +1,4 @@
+import re
 from contextlib import nullcontext as does_not_raise

 import torch
@@ -69,9 +70,13 @@ def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: Exc

 def test_representations():
     A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu"))
-    assert repr(A) == "GradDrop(leak=tensor([0., 1.]))"
+    assert re.match(
+        r"GradDrop\(f=<function _identity at 0x[0-9a-fA-F]+>, leak=tensor\(\[0\., 1\.\]\)\)",
+        repr(A),
+    )
+
     assert str(A) == "GradDrop([0., 1.])"

     A = GradDrop()
-    assert repr(A) == "GradDrop(leak=None)"
+    assert re.match(r"GradDrop\(f=<function _identity at 0x[0-9a-fA-F]+>, leak=None\)", repr(A))
     assert str(A) == "GradDrop"
tests/unit/aggregation/test_nash_mtl.py (2 additions, 2 deletions)

@@ -55,6 +55,6 @@ def test_nash_mtl_reset():


 def test_representations():
-    A = NashMTL(n_tasks=2)
-    assert repr(A) == "NashMTL(n_tasks=2)"
+    A = NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)
+    assert repr(A) == "NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)"
     assert str(A) == "NashMTL"
