Skip to content

Commit 3e3e0e6

Browse files
authored
refactor(aggregation): Remove _check_is_finite (#381)
* Remove _check_is_finite * Add changelog entry
1 parent f5b39ca commit 3e3e0e6

File tree

5 files changed

+4
-28
lines changed

5 files changed

+4
-28
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ changes that do not affect the user.
2121
always be imported directly from the `torchjd` package (e.g.
2222
`from torchjd.autojac.mtl_backward import mtl_backward` must be changed to
2323
`from torchjd import mtl_backward`).
24+
- Removed the check that the input Jacobian matrix provided to an aggregator does not contain `nan`,
25+
`inf` or `-inf` values. This check was costly in memory and in time for large matrices so this
26+
should improve performance. However, if the optimization diverges for some reason (for instance
27+
due to a too large learning rate), the resulting exceptions may come from other sources.
2428

2529
### Fixed
2630

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,6 @@ def _check_is_matrix(matrix: Tensor) -> None:
2020
f"{matrix.shape}`."
2121
)
2222

23-
@staticmethod
24-
def _check_is_finite(matrix: Tensor) -> None:
25-
if not matrix.isfinite().all():
26-
raise ValueError(
27-
"Parameter `matrix` should be a tensor of finite elements (no nan, inf or -inf "
28-
f"values). Found `matrix = {matrix}`."
29-
)
30-
3123
@abstractmethod
3224
def forward(self, matrix: Tensor) -> Tensor:
3325
"""Computes the aggregation from the input matrix."""
@@ -69,8 +61,6 @@ def combine(matrix: Tensor, weights: Tensor) -> Tensor:
6961

7062
def forward(self, matrix: Tensor) -> Tensor:
7163
self._check_is_matrix(matrix)
72-
self._check_is_finite(matrix)
73-
7464
weights = self.weighting(matrix)
7565
vector = self.combine(matrix, weights)
7666
return vector

src/torchjd/aggregation/_graddrop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
5555
def forward(self, matrix: Tensor) -> Tensor:
5656
self._check_is_matrix(matrix)
5757
self._check_matrix_has_enough_rows(matrix)
58-
self._check_is_finite(matrix)
5958

6059
if matrix.shape[0] == 0 or matrix.shape[1] == 0:
6160
return torch.zeros(matrix.shape[1], dtype=matrix.dtype, device=matrix.device)

src/torchjd/aggregation/_trimmed_mean.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __init__(self, trim_number: int):
4747
def forward(self, matrix: Tensor) -> Tensor:
4848
self._check_is_matrix(matrix)
4949
self._check_matrix_has_enough_rows(matrix)
50-
self._check_is_finite(matrix)
5150

5251
n_rows = matrix.shape[0]
5352
n_remaining = n_rows - 2 * self.trim_number

tests/unit/aggregation/test_aggregator_bases.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,3 @@
2121
def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext):
2222
with expectation:
2323
Aggregator._check_is_matrix(torch.randn(shape))
24-
25-
26-
@mark.parametrize(
27-
["value", "expectation"],
28-
[
29-
(0.0, does_not_raise()),
30-
(torch.nan, raises(ValueError)),
31-
(torch.inf, raises(ValueError)),
32-
(-torch.inf, raises(ValueError)),
33-
],
34-
)
35-
def test_check_is_finite(value: float, expectation: ExceptionContext):
36-
matrix = torch.ones([5, 5])
37-
matrix[1, 2] = value
38-
with expectation:
39-
Aggregator._check_is_finite(matrix)

0 commit comments

Comments
 (0)