Skip to content

Commit ce56bdb

Browse files
committed
minor changes + test data addition
1 parent 615a4a9 commit ce56bdb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+14
-6
lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ def compute(self):
5050

5151

5252
class BalancedAccuracy(torchmetrics.Metric):
53+
"""Balanced Accuracy = (TPR + TNR) / 2 = ( TP/(TP + FN) + (TN)/(TN + FP) ) / 2
54+
55+
This metric computes the balanced accuracy, which is the average of true positive rate (TPR)
56+
and true negative rate (TNR). It is useful for imbalanced datasets where the classes are not
57+
represented equally.
58+
"""
59+
5360
def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
5461
super().__init__(dist_sync_on_step=dist_sync_on_step)
5562

tests/testCustomBalancedAccuracyMetric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import random
66

77

8-
class TestCustomMacroF1Metric(unittest.TestCase):
8+
class TestCustomBalancedAccuracyMetric(unittest.TestCase):
99

1010
@classmethod
1111
def setUpClass(cls) -> None:
@@ -31,7 +31,7 @@ def test_iterative_vs_single_call_approach(self):
3131

3232
def test_metric_against_realistic_data(self):
3333
"""Test the custom metric against the standard on realistic data"""
34-
directory_path = "CheBIOver100_test"
34+
directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
3535
abs_path = os.path.join(os.getcwd(), directory_path)
3636
print(f"Checking data from - {abs_path}")
3737
num_of_files = len(os.listdir(abs_path)) // 2

tests/testCustomMacroF1Metric.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ class TestCustomMacroF1Metric(unittest.TestCase):
1212
def setUpClass(cls) -> None:
1313
cls.device = "cuda" if torch.cuda.is_available() else "cpu"
1414

15-
@unittest.expectedFailure
1615
def test_all_predictions_are_1_half_labels_are_1(self):
1716
"""Test custom metric against standard metric for the scenario where all prediction are 1 but only half of
1817
the labels are 1"""
@@ -54,8 +53,10 @@ def test_all_predictions_are_1_half_labels_are_1(self):
5453
# recall = [1, 1, 1, 1, 1] / [1, 1, 1, 1, 1] = [1, 1, 1, 1, 1]
5554
# classwise_f1 = [2, 2, 2, 2, 2] / [2, 2, 2, 2, 2] = [1, 1, 1, 1, 1]
5655
# mean = 5/5 = 1 (because of masking we averaging with across positive labels only)
57-
58-
self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)
56+
self.assertAlmostEqual(macro_f1_custom_score, 1, places=4)
57+
self.assertNotAlmostEqual(
58+
macro_f1_custom_score, macro_f1_standard_score, places=4
59+
)
5960

6061
def test_all_labels_are_1_half_predictions_are_1(self):
6162
"""Test custom metric against standard metric for the scenario where all labels are 1 but only half of
@@ -98,7 +99,7 @@ def test_iterative_vs_single_call_approach(self):
9899
@unittest.expectedFailure
99100
def test_metric_against_realistic_data(self):
100101
"""Test the custom metric against the standard on realistic data"""
101-
directory_path = "CheBIOver100_test"
102+
directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
102103
abs_path = os.path.join(os.getcwd(), directory_path)
103104
print(f"Checking data from - {abs_path}")
104105
num_of_files = len(os.listdir(abs_path)) // 2
500 KB
Binary file not shown.
500 KB
Binary file not shown.
500 KB
Binary file not shown.
500 KB
Binary file not shown.
500 KB
Binary file not shown.
500 KB
Binary file not shown.
500 KB
Binary file not shown.

0 commit comments

Comments
 (0)