We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
_check_losses_are_scalar
1 parent 08cef61 commit b381581Copy full SHA for b381581
src/torchjd/_autojac/_mtl_backward.py
@@ -190,7 +190,7 @@ def _create_task_transform(
190
return backward_task
191
192
193
-def _check_losses_are_scalar(losses: Sequence[Tensor]) -> None:
+def _check_losses_are_scalar(losses: Iterable[Tensor]) -> None:
194
for loss in losses:
195
if loss.ndim > 0:
196
raise ValueError("`losses` should contain only scalars.")
0 commit comments