Skip to content

Commit f5b39ca

Browse files
authored
refactor(autojac): Remove _disunite check (#377)
1 parent ba9d78f commit f5b39ca

File tree

2 files changed

+0
-42
lines changed

2 files changed

+0
-42
lines changed

src/torchjd/_autojac/_transform/_aggregate.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,6 @@ def _unite(jacobian_matrices: OrderedDict[Tensor, Tensor]) -> Tensor:
115115
def _disunite(
116116
united_gradient_vector: Tensor, jacobian_matrices: OrderedDict[Tensor, Tensor]
117117
) -> GradientVectors:
118-
expected_length = sum([matrix.shape[1] for matrix in jacobian_matrices.values()])
119-
if len(united_gradient_vector) != expected_length:
120-
raise ValueError(
121-
"Parameter `united_gradient_vector` should be a vector with length equal to the sum"
122-
"of the numbers of columns in the jacobian matrices. Found"
123-
f"`len(united_gradient_vector) = {len(united_gradient_vector)}` and the sum of the "
124-
f"numbers of columns in the jacobian matrices is {expected_length}."
125-
)
126-
127118
gradient_vectors = {}
128119
start = 0
129120
for key, jacobian_matrix in jacobian_matrices.items():

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import math
2-
from collections import OrderedDict
32

43
import torch
54
from pytest import mark, raises
6-
from torch import Tensor
75
from unit.conftest import DEVICE
86

97
from torchjd._autojac._transform import (
@@ -72,37 +70,6 @@ def test_aggregate_matrices_empty_dict():
7270
assert len(gradient_vectors) == 0
7371

7472

75-
@mark.parametrize(
76-
["united_gradient_vector", "jacobian_matrices"],
77-
[
78-
(
79-
torch.ones(10),
80-
{ # Total number of parameters according to the united gradient vector: 10
81-
torch.ones(5): torch.ones(2, 5),
82-
torch.ones(4): torch.ones(2, 4),
83-
},
84-
), # Total number of parameters according to the jacobian matrices: 9
85-
(
86-
torch.ones(10),
87-
{ # Total number of parameters according to the united gradient vector: 10
88-
torch.ones(5): torch.ones(2, 5),
89-
torch.ones(3): torch.ones(2, 3),
90-
torch.ones(3): torch.ones(2, 3),
91-
},
92-
), # Total number of parameters according to the jacobian matrices: 11
93-
],
94-
)
95-
def test_disunite_wrong_vector_length(
96-
united_gradient_vector: Tensor, jacobian_matrices: dict[Tensor, Tensor]
97-
):
98-
"""
99-
Tests that the _disunite method raises a ValueError when used on vectors of the wrong length.
100-
"""
101-
102-
with raises(ValueError):
103-
_AggregateMatrices._disunite(united_gradient_vector, OrderedDict(jacobian_matrices))
104-
105-
10673
def test_matrixify():
10774
"""Tests that the Matrixify transform correctly creates matrices from the jacobians."""
10875

0 commit comments

Comments
 (0)