Skip to content

Commit d1aedaa

Browse files
committed
refactor: assert_close
1 parent 6225e9a commit d1aedaa

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

tests/test_loss_functions.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import pytest
32
import torch
43

@@ -70,14 +69,14 @@ def test_focal_loss():
7069

7170

7271
@torch.no_grad()
73-
def test_cosine_focal_loss():
72+
def test_focal_cosine_loss():
7473
criterion = FocalCosineLoss(alpha=1.0, gamma=2.0, focal_weight=0.1)
7574

7675
y_pred = torch.FloatTensor([[0.9, 0.1, 0.1], [0.2, 0.9, 0.1], [0.2, 0.1, 0.1]])
7776
y_true = torch.LongTensor([0, 1, 2])
7877
loss = criterion(y_pred, y_true)
7978

80-
assert float(loss) == pytest.approx(0.241352, abs=1e-6)
79+
assert float(loss) == pytest.approx(0.2413520, abs=1e-6)
8180

8281

8382
@torch.no_grad()
@@ -88,7 +87,7 @@ def test_soft_f1_loss():
8887
y_true = torch.FloatTensor([0.0] * 5 + [1.0] * 5)
8988
loss = criterion(y_pred, y_true)
9089

91-
np.testing.assert_almost_equal(loss.item(), 0.38905364)
90+
assert float(loss) == pytest.approx(0.38905364, abs=1e-6)
9291

9392

9493
@torch.no_grad()
@@ -274,7 +273,7 @@ def test_ldam_loss():
274273
y_true = torch.LongTensor([3, 0])
275274
loss = criterion(y_pred, y_true)
276275

277-
np.testing.assert_almost_equal(loss.item(), 4.5767049)
276+
assert loss.item() == pytest.approx(0.5767049, abs=1e-6)
278277

279278

280279
def test_bi_tempered_log_loss_func():
@@ -324,7 +323,7 @@ def test_bi_tempered_log_loss(recipe):
324323
loss = criterion(y_pred, y_true)
325324

326325
if reduction == 'none':
327-
torch.testing.assert_allclose(loss, expected_loss, rtol=1e-4, atol=1e-4)
326+
torch.testing.assert_close(loss, expected_loss, rtol=1e-4, atol=1e-4)
328327
else:
329328
assert float(loss) == pytest.approx(expected_loss, abs=1e-6)
330329

@@ -345,6 +344,6 @@ def test_binary_bi_tempered_log_loss(recipe):
345344
loss = criterion(y_pred, y_true)
346345

347346
if reduction == 'none':
348-
torch.testing.assert_allclose(loss, expected_loss, rtol=1e-4, atol=1e-4)
347+
torch.testing.assert_close(loss, expected_loss, rtol=1e-4, atol=1e-4)
349348
else:
350349
assert float(loss) == pytest.approx(expected_loss, abs=1e-6)

0 commit comments

Comments
 (0)