Skip to content

Commit 9b3cc37

Browse files
authored
fix(aggregation): Fix raise_non_differentiable_error type (#361)
1 parent 75ad2ae commit 9b3cc37

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/torchjd/aggregation/_utils/non_differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ def __init__(self, module: nn.Module):
66
super().__init__(f"Trying to differentiate through {module}, which is not differentiable.")
77

88

9-
def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...]) -> None:
9+
def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...] | Tensor) -> None:
1010
raise NonDifferentiableError(module)

0 commit comments

Comments
 (0)