Commit 9182b4a

refactor: Add casts when necessary (#372)
* Add cast in get_leaf_tensors
* Add cast of w_opt to np.ndarray in _CAGradWeighting.forward
* Add cast to type[TensorDict] in _least_common_ancestor
1 parent 3e3e0e6 commit 9182b4a

File tree

3 files changed: +8 -5 lines

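For context: `typing.cast` only informs the static type checker; at runtime it returns its second argument unchanged, performing no conversion or check. A minimal standalone sketch with a hypothetical helper (not code from this repository):

```python
from typing import cast

def first_int(values: list[object]) -> int:
    # The checker only knows each element as `object`; we assert the first is an int.
    # cast() does not convert or validate anything at runtime; it returns its argument as-is.
    return cast(int, values[0])

print(first_int([42, "forty-two"]))  # 42
```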

src/torchjd/_autojac/_transform/_tensor_dict.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import TypeVar
+from typing import TypeVar, cast
 
 from torch import Tensor
 
@@ -131,7 +131,7 @@ def _least_common_ancestor(first: type[TensorDict], second: type[TensorDict]) ->
     output = TensorDict
     for candidate_type in first_mro:
         if issubclass(second, candidate_type):
-            output = candidate_type
+            output = cast(type[TensorDict], candidate_type)
             break
     return output
 

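The likely reason for this cast (an inference, not stated in the diff): iterating over an MRO such as `type.__mro__` yields values that checkers see as plain `type`, so assigning one to a variable inferred as `type[TensorDict]` needs narrowing. A self-contained sketch of the same pattern, using hypothetical classes `Base` and `Child`:

```python
from typing import cast

class Base: ...
class Child(Base): ...

def least_common_ancestor(first: type[Base], second: type[Base]) -> type[Base]:
    output: type[Base] = Base
    for candidate_type in first.__mro__:  # elements are typed as plain `type`
        if issubclass(second, candidate_type):
            # No-op at runtime; narrows `type` to `type[Base]` for the checker.
            output = cast(type[Base], candidate_type)
            break
    return output

print(least_common_ancestor(Child, Child))  # <class '__main__.Child'>
```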
src/torchjd/_autojac/_utils.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 from collections import deque
 from collections.abc import Iterable, Sequence
+from typing import cast
 
 from torch import Tensor
 from torch.autograd.graph import Node
@@ -48,8 +49,8 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O
         raise ValueError("All `excluded` tensors should have a `grad_fn`.")
 
     accumulate_grads = _get_descendant_accumulate_grads(
-        roots=OrderedSet([tensor.grad_fn for tensor in tensors]),
-        excluded_nodes={tensor.grad_fn for tensor in excluded},
+        roots=cast(OrderedSet[Node], OrderedSet([tensor.grad_fn for tensor in tensors])),
+        excluded_nodes=cast(set[Node], {tensor.grad_fn for tensor in excluded}),
     )
 
     # accumulate_grads contains instances of AccumulateGrad, which contain a `variable` field.

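The likely motivation here (an assumption based on PyTorch's type stubs): `Tensor.grad_fn` is typed `Node | None`, so collections built from it are inferred as containing `Node | None`; the earlier `ValueError` guards rule out `None`, and the casts record that fact for the checker. A minimal sketch of the same guard-then-cast pattern, with a hypothetical helper:

```python
from typing import cast

import torch
from torch.autograd.graph import Node

def grad_fns(tensors: list[torch.Tensor]) -> set[Node]:
    if any(t.grad_fn is None for t in tensors):
        raise ValueError("All tensors should have a `grad_fn`.")
    # `Tensor.grad_fn` is `Node | None`; the guard above excludes None,
    # and the cast conveys that to the type checker at no runtime cost.
    return cast(set[Node], {t.grad_fn for t in tensors})

x = torch.ones(2, requires_grad=True)
print(grad_fns([x * 2]))  # e.g. {<MulBackward0 object at 0x...>}
```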
src/torchjd/aggregation/_cagrad.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
+from typing import cast
+
 from ._utils.check_dependencies import check_dependencies_are_installed
 from ._weighting_bases import PSDMatrix, Weighting
 
@@ -101,7 +103,7 @@ def forward(self, gramian: Tensor) -> Tensor:
         problem = cp.Problem(objective=cp.Minimize(cost), constraints=[w >= 0, cp.sum(w) == 1])
 
         problem.solve(cp.CLARABEL)
-        w_opt = w.value
+        w_opt = cast(np.ndarray, w.value)
 
         g_w_norm = np.linalg.norm(reduced_array.T @ w_opt, 2).item()
         if g_w_norm >= self.norm_eps:

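The likely reason for this cast (an assumption based on cvxpy's API): `Variable.value` is `None` until a successful solve, so checkers see an optional type, while the subsequent arithmetic expects an array. A standalone sketch, assuming cvxpy and numpy are installed (default solver used here; the commit itself passes cp.CLARABEL):

```python
from typing import cast

import cvxpy as cp
import numpy as np

w = cp.Variable(2)
problem = cp.Problem(cp.Minimize(cp.sum_squares(w)), constraints=[cp.sum(w) == 1])
problem.solve()

# `w.value` is None before a solve, so its static type is optional.
# The cast asserts that a solution is now available.
w_opt = cast(np.ndarray, w.value)
print(np.round(w_opt, 3))  # approximately [0.5 0.5]
```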