Skip to content

Commit 4b67d1d

Browse files
authored
style(autojac): Ignore type error in get_leaf_tensors (#371)
1 parent 3a27958 commit 4b67d1d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/torchjd/_autojac/_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O
5151
roots=OrderedSet([tensor.grad_fn for tensor in tensors]),
5252
excluded_nodes={tensor.grad_fn for tensor in excluded},
5353
)
54-
leaves = OrderedSet([g.variable for g in accumulate_grads])
54+
55+
# accumulate_grads contains instances of AccumulateGrad, which contain a `variable` field.
56+
# They cannot be typed as such because AccumulateGrad is not public.
57+
leaves = OrderedSet([g.variable for g in accumulate_grads]) # type: ignore[attr-defined]
5558

5659
return leaves
5760

0 commit comments

Comments
 (0)