1- import numpy as np
21import pytest
32import torch
43
@@ -70,14 +69,14 @@ def test_focal_loss():
7069
7170
7271@torch .no_grad ()
73- def test_cosine_focal_loss ():
72+ def test_focal_cosine_loss ():
7473 criterion = FocalCosineLoss (alpha = 1.0 , gamma = 2.0 , focal_weight = 0.1 )
7574
7675 y_pred = torch .FloatTensor ([[0.9 , 0.1 , 0.1 ], [0.2 , 0.9 , 0.1 ], [0.2 , 0.1 , 0.1 ]])
7776 y_true = torch .LongTensor ([0 , 1 , 2 ])
7877 loss = criterion (y_pred , y_true )
7978
80- assert float (loss ) == pytest .approx (0.241352 , abs = 1e-6 )
79+ assert float (loss ) == pytest .approx (0.2413520 , abs = 1e-6 )
8180
8281
8382@torch .no_grad ()
@@ -88,7 +87,7 @@ def test_soft_f1_loss():
8887 y_true = torch .FloatTensor ([0.0 ] * 5 + [1.0 ] * 5 )
8988 loss = criterion (y_pred , y_true )
9089
91- np . testing . assert_almost_equal (loss . item (), 0.38905364 )
90+ assert float (loss ) == pytest . approx ( 0.38905364 , abs = 1e-6 )
9291
9392
9493@torch .no_grad ()
@@ -274,7 +273,7 @@ def test_ldam_loss():
274273 y_true = torch .LongTensor ([3 , 0 ])
275274 loss = criterion (y_pred , y_true )
276275
277- np . testing . assert_almost_equal ( loss .item (), 4. 5767049 )
276+ assert loss .item () == pytest . approx ( 0. 5767049, abs = 1e-6 )
278277
279278
280279def test_bi_tempered_log_loss_func ():
@@ -324,7 +323,7 @@ def test_bi_tempered_log_loss(recipe):
324323 loss = criterion (y_pred , y_true )
325324
326325 if reduction == 'none' :
327- torch .testing .assert_allclose (loss , expected_loss , rtol = 1e-4 , atol = 1e-4 )
326+ torch .testing .assert_close (loss , expected_loss , rtol = 1e-4 , atol = 1e-4 )
328327 else :
329328 assert float (loss ) == pytest .approx (expected_loss , abs = 1e-6 )
330329
@@ -345,6 +344,6 @@ def test_binary_bi_tempered_log_loss(recipe):
345344 loss = criterion (y_pred , y_true )
346345
347346 if reduction == 'none' :
348- torch .testing .assert_allclose (loss , expected_loss , rtol = 1e-4 , atol = 1e-4 )
347+ torch .testing .assert_close (loss , expected_loss , rtol = 1e-4 , atol = 1e-4 )
349348 else :
350349 assert float (loss ) == pytest .approx (expected_loss , abs = 1e-6 )
0 commit comments