We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
raise_non_differentiable_error
1 parent 75ad2ae commit 9b3cc37Copy full SHA for 9b3cc37
src/torchjd/aggregation/_utils/non_differentiable.py
@@ -6,5 +6,5 @@ def __init__(self, module: nn.Module):
6
super().__init__(f"Trying to differentiate through {module}, which is not differentiable.")
7
8
9
-def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...]) -> None:
+def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...] | Tensor) -> None:
10
raise NonDifferentiableError(module)
0 commit comments