Skip to content

Commit 3b497b3

Browse files
authored
test: Fix test_mtl_backward warnings (#404)
* Before this, we had UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
1 parent 3ef5c3f commit 3b497b3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/unit/autojac/test_mtl_backward.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def test_tasks_params_overlap():
496496
assert_close(p1.grad, f * p12)
497497
assert_close(p12.grad, f * p1 + f * p2)
498498

499-
J = tensor_([[-p1 * p12, p1 * p12], [-p2 * p12, p2 * p12]])
499+
J = tensor_([[-8.0, 8.0], [-12.0, 12.0]])
500500
assert_close(p0.grad, aggregator(J))
501501

502502

@@ -515,7 +515,7 @@ def test_tasks_params_are_the_same():
515515

516516
assert_close(p1.grad, f + 1)
517517

518-
J = tensor_([[-p1, p1], [-1.0, 1.0]])
518+
J = tensor_([[-2.0, 2.0], [-1.0, 1.0]])
519519
assert_close(p0.grad, aggregator(J))
520520

521521

@@ -539,7 +539,7 @@ def test_task_params_is_subset_of_other_task_params():
539539
assert_close(p2.grad, y1)
540540
assert_close(p1.grad, p2 * f + f)
541541

542-
J = tensor_([[-p1, p1], [-p1 * p2, p1 * p2]])
542+
J = tensor_([[-2.0, 2.0], [-6.0, 6.0]])
543543
assert_close(p0.grad, aggregator(J))
544544

545545

0 commit comments

Comments
 (0)