Skip to content

Commit 9358786

Browse files
committed
update: test_unit_norm
1 parent bd7af02 commit 9358786

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/test_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,9 @@ def test_clip_grad_norm():
3737

3838

3939
def test_unit_norm():
40-
pass
40+
x = torch.arange(0, 10, dtype=torch.float32)
41+
42+
np.testing.assert_approx_equal(unit_norm(x).numpy(), 16.8819, significant=4)
43+
np.testing.assert_approx_equal(unit_norm(x.view(1, 10)).numpy(), 16.8819, significant=4)
44+
np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1)).numpy(), 16.8819, significant=4)
45+
np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1, 1, 1)).numpy(), 16.8819, significant=4)

0 commit comments

Comments
 (0)