Skip to content

Commit bd7af02

Browse files
committed
refactor: test_clip_grad_norm
1 parent 6eb1d6d commit bd7af02

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def test_clip_grad_norm():
3232
x = torch.arange(0, 10, dtype=torch.float32, requires_grad=True)
3333
x.grad = torch.arange(0, 10, dtype=torch.float32)
3434

35-
np.testing.assert_approx_equal(float(clip_grad_norm(x)), 16.881943016134134, significant=4)
36-
np.testing.assert_approx_equal(float(clip_grad_norm(x, max_norm=2)), 16.881943016134134, significant=4)
35+
np.testing.assert_approx_equal(clip_grad_norm(x), 16.881943016134134, significant=4)
36+
np.testing.assert_approx_equal(clip_grad_norm(x, max_norm=2), 16.881943016134134, significant=4)
3737

3838

3939
def test_unit_norm():

0 commit comments

Comments
 (0)