Skip to content

Commit 40bd21b

Browse files
committed
update: test_normalized_gradient
1 parent 831deb2 commit 40bd21b

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tests/test_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,21 @@ def test_has_overflow():
3636

3737
def test_normalized_gradient():
3838
x = torch.arange(0, 10, dtype=torch.float32)
39+
normalize_gradient(x)
3940

4041
np.testing.assert_allclose(
41-
normalize_gradient(x).numpy(),
42+
x.numpy(),
4243
np.asarray([0.0000, 0.3303, 0.6606, 0.9909, 1.3212, 1.6514, 1.9817, 2.3120, 2.6423, 2.9726]),
4344
rtol=1e-4,
4445
atol=1e-4,
4546
)
4647

48+
x = torch.arange(0, 10, dtype=torch.float32)
49+
normalize_gradient(x.view(1, 10), use_channels=True)
50+
4751
np.testing.assert_allclose(
48-
normalize_gradient(x.view(1, 10), use_channels=True).numpy(),
49-
np.asarray([[0.0000, 0.3303, 0.6606, 0.9909, 1.3212, 1.6514, 1.9817, 2.3120, 2.6423, 2.9726]]),
52+
x.numpy(),
53+
np.asarray([0.0000, 0.3303, 0.6606, 0.9909, 1.3212, 1.6514, 1.9817, 2.3120, 2.6423, 2.9726]),
5054
rtol=1e-4,
5155
atol=1e-4,
5256
)

0 commit comments

Comments
 (0)