Skip to content

Commit 615a4a9

Browse files
committed
black format
1 parent a4c970d commit 615a4a9

File tree

2 files changed

+47
-26
lines changed

2 files changed

+47
-26
lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def compute(self):
4848
classwise_f1 = classwise_f1.nan_to_num()
4949
return torch.mean(classwise_f1)
5050

51+
5152
class BalancedAccuracy(torchmetrics.Metric):
5253
def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
5354
super().__init__(dist_sync_on_step=dist_sync_on_step)
@@ -79,14 +80,22 @@ def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
7980
self.threshold = threshold
8081

8182
def update(self, preds: torch.Tensor, labels: torch.Tensor):
82-
""""Update the TPs, TNs ,FPs and FNs """
83+
"""Update the TPs, TNs ,FPs and FNs"""
8384

8485
# Size: Batch_size x Num_of_Classes;
8586
# summing over 1st dimension (dim=0), gives us the True positives per class
86-
tps = torch.sum(torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0)
87-
fps = torch.sum(torch.logical_and(preds > self.threshold, ~labels.to(torch.bool)), dim=0)
88-
tns = torch.sum(torch.logical_and(preds <= self.threshold, ~labels.to(torch.bool)), dim=0)
89-
fns = torch.sum(torch.logical_and(preds <= self.threshold, labels.to(torch.bool)), dim=0)
87+
tps = torch.sum(
88+
torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
89+
)
90+
fps = torch.sum(
91+
torch.logical_and(preds > self.threshold, ~labels.to(torch.bool)), dim=0
92+
)
93+
tns = torch.sum(
94+
torch.logical_and(preds <= self.threshold, ~labels.to(torch.bool)), dim=0
95+
)
96+
fns = torch.sum(
97+
torch.logical_and(preds <= self.threshold, labels.to(torch.bool)), dim=0
98+
)
9099

91100
# Size: Num_of_Classes;
92101
self.true_positives += tps
@@ -104,4 +113,4 @@ def compute(self):
104113
tnr = tnr.nan_to_num()
105114

106115
balanced_acc = (tpr + tnr) / 2
107-
return torch.mean(balanced_acc)
116+
return torch.mean(balanced_acc)

tests/testCustomBalancedAccuracyMetric.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@ def test_iterative_vs_single_call_approach(self):
1515
"""Test the custom metric implementation in update fashion approach against
1616
the single call approach"""
1717

18-
preds = torch.tensor([[1, 1, 0, 1],
19-
[1, 0, 1, 1],
20-
[0, 1, 0, 1]])
21-
label = torch.tensor([[0, 0, 0, 0],
22-
[0, 0, 1, 1],
23-
[0, 1, 0, 1]])
18+
preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
19+
label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
2420

2521
num_labels = label.shape[1]
2622
iterative_custom_metric = BalancedAccuracy(num_labels=num_labels)
@@ -41,34 +37,40 @@ def test_metric_against_realistic_data(self):
4137
num_of_files = len(os.listdir(abs_path)) // 2
4238

4339
# load single file to get the num of labels for metric class instantiation
44-
labels = torch.load(f'{directory_path}/labels{0:03d}.pt', map_location=torch.device(self.device))
40+
labels = torch.load(
41+
f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
42+
)
4543
num_labels = labels.shape[1]
4644
balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
4745

4846
for i in range(num_of_files):
49-
labels = torch.load(f'{directory_path}/labels{i:03d}.pt', map_location=torch.device(self.device))
50-
preds = torch.load(f'{directory_path}/preds{i:03d}.pt', map_location=torch.device(self.device))
47+
labels = torch.load(
48+
f"{directory_path}/labels{i:03d}.pt",
49+
map_location=torch.device(self.device),
50+
)
51+
preds = torch.load(
52+
f"{directory_path}/preds{i:03d}.pt",
53+
map_location=torch.device(self.device),
54+
)
5155
balanced_acc_custom.update(preds, labels)
5256

5357
balanced_acc_custom_score = balanced_acc_custom.compute().item()
5458
print(f"Balanced Accuracy for realistic data: {balanced_acc_custom_score}")
5559

5660
def test_case_when_few_class_has_no_labels(self):
5761
"""Test custom metric against standard metric for the scenario where some class has no labels"""
58-
preds = torch.tensor([[1, 1, 0, 1],
59-
[1, 0, 1, 1],
60-
[0, 1, 0, 1]])
61-
label = torch.tensor([[0, 0, 0, 0], # no labels
62-
[0, 0, 1, 1],
63-
[0, 1, 0, 1]])
62+
preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
63+
label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]]) # no labels
6464

6565
# tp = [0, 1, 1, 2], fp = [2, 1, 0, 1], tn = [1, 1, 2, 0], fn = [0, 0, 0, 0]
6666
# tpr = [0, 1, 1, 2] / ([0, 1, 1, 2] + [0, 0, 0, 0]) = [0, 1, 1, 1]
6767
# tnr = [1, 1, 2, 0] / ([1, 1, 2, 0] + [2, 1, 0, 1]) = [0.33333, 0.5, 1, 0]
6868
# balanced_accuracy = ([0, 1, 1, 1] + [0.33333, 0.5, 1, 0]) / 2 = ([0.16666667, 0.75, 1, 0.5]
6969
# mean bal accuracy = 0.6041666666666666
7070

71-
balanced_acc_score = self.__get_custom_metric_score(preds, label, label.shape[1])
71+
balanced_acc_score = self.__get_custom_metric_score(
72+
preds, label, label.shape[1]
73+
)
7274

7375
self.assertAlmostEqual(balanced_acc_score, 0.6041666666, places=4)
7476

@@ -78,7 +80,10 @@ def test_all_predictions_are_1_half_labels_are_1(self):
7880
preds = torch.ones((1, 900), dtype=torch.int)
7981
label = torch.ones((1, 900), dtype=torch.int)
8082

81-
mask = [[True] * (label.size(1) // 2) + [False] * (label.size(1) - (label.size(1) // 2))]
83+
mask = [
84+
[True] * (label.size(1) // 2)
85+
+ [False] * (label.size(1) - (label.size(1) // 2))
86+
]
8287
random.shuffle(mask[0])
8388
label[torch.tensor(mask)] = 0
8489

@@ -88,7 +93,9 @@ def test_all_predictions_are_1_half_labels_are_1(self):
8893
# tnr = tn / (tn + fp) = [0, 0, 0, 0]
8994
# balanced accuracy = 1 / 4 = 0.25
9095

91-
balanced_acc_custom_score = self.__get_custom_metric_score(preds, label, label.shape[1])
96+
balanced_acc_custom_score = self.__get_custom_metric_score(
97+
preds, label, label.shape[1]
98+
)
9299
self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
93100

94101
def test_all_labels_are_1_half_predictions_are_1(self):
@@ -97,7 +104,10 @@ def test_all_labels_are_1_half_predictions_are_1(self):
97104
preds = torch.ones((1, 900), dtype=torch.int)
98105
label = torch.ones((1, 900), dtype=torch.int)
99106

100-
mask = [[True] * (label.size(1) // 2) + [False] * (label.size(1) - (label.size(1) // 2))]
107+
mask = [
108+
[True] * (label.size(1) // 2)
109+
+ [False] * (label.size(1) - (label.size(1) // 2))
110+
]
101111
random.shuffle(mask[0])
102112
preds[torch.tensor(mask)] = 0
103113

@@ -107,7 +117,9 @@ def test_all_labels_are_1_half_predictions_are_1(self):
107117
# tnr = tn / (tn + fp) = [0, 0, 0, 0]
108118
# balanced accuracy = 1 / 4 = 0.25
109119

110-
balanced_acc_custom_score = self.__get_custom_metric_score(preds, label, label.shape[1])
120+
balanced_acc_custom_score = self.__get_custom_metric_score(
121+
preds, label, label.shape[1]
122+
)
111123
self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
112124

113125
@staticmethod

0 commit comments

Comments
 (0)