@@ -15,12 +15,8 @@ def test_iterative_vs_single_call_approach(self):
1515 """Test the custom metric implementation in update fashion approach against
1616 the single call approach"""
1717
18- preds = torch .tensor ([[1 , 1 , 0 , 1 ],
19- [1 , 0 , 1 , 1 ],
20- [0 , 1 , 0 , 1 ]])
21- label = torch .tensor ([[0 , 0 , 0 , 0 ],
22- [0 , 0 , 1 , 1 ],
23- [0 , 1 , 0 , 1 ]])
18+ preds = torch .tensor ([[1 , 1 , 0 , 1 ], [1 , 0 , 1 , 1 ], [0 , 1 , 0 , 1 ]])
19+ label = torch .tensor ([[0 , 0 , 0 , 0 ], [0 , 0 , 1 , 1 ], [0 , 1 , 0 , 1 ]])
2420
2521 num_labels = label .shape [1 ]
2622 iterative_custom_metric = BalancedAccuracy (num_labels = num_labels )
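For reference, the two call styles this test compares look roughly like this. A minimal sketch, assuming `BalancedAccuracy` follows the torchmetrics `Metric` interface, where calling the metric object scores one batch directly while `update`/`compute` accumulate across batches:

    import torch

    preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
    label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])

    # update fashion: accumulate state batch by batch, then compute once
    metric = BalancedAccuracy(num_labels=label.shape[1])  # the metric under test
    metric.update(preds[:2], label[:2])
    metric.update(preds[2:], label[2:])
    iterative_score = metric.compute().item()

    # single call: pass the full batch in one shot
    single_call_score = BalancedAccuracy(num_labels=label.shape[1])(preds, label).item()

    assert abs(iterative_score - single_call_score) < 1e-6  # the test's core check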
@@ -41,34 +37,40 @@ def test_metric_against_realistic_data(self):
         num_of_files = len(os.listdir(abs_path)) // 2
 
         # load single file to get the num of labels for metric class instantiation
-        labels = torch.load(f'{directory_path}/labels{0:03d}.pt', map_location=torch.device(self.device))
+        labels = torch.load(
+            f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
+        )
         num_labels = labels.shape[1]
         balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
 
         for i in range(num_of_files):
-            labels = torch.load(f'{directory_path}/labels{i:03d}.pt', map_location=torch.device(self.device))
-            preds = torch.load(f'{directory_path}/preds{i:03d}.pt', map_location=torch.device(self.device))
+            labels = torch.load(
+                f"{directory_path}/labels{i:03d}.pt",
+                map_location=torch.device(self.device),
+            )
+            preds = torch.load(
+                f"{directory_path}/preds{i:03d}.pt",
+                map_location=torch.device(self.device),
+            )
             balanced_acc_custom.update(preds, labels)
 
         balanced_acc_custom_score = balanced_acc_custom.compute().item()
         print(f"Balanced Accuracy for realistic data: {balanced_acc_custom_score}")
 
     def test_case_when_few_class_has_no_labels(self):
         """Test custom metric against standard metric for the scenario where some class has no labels"""
-        preds = torch.tensor([[1, 1, 0, 1],
-                              [1, 0, 1, 1],
-                              [0, 1, 0, 1]])
-        label = torch.tensor([[0, 0, 0, 0],  # no labels
-                              [0, 0, 1, 1],
-                              [0, 1, 0, 1]])
+        preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
+        label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])  # no labels
 
         # tp = [0, 1, 1, 2], fp = [2, 1, 0, 1], tn = [1, 1, 2, 0], fn = [0, 0, 0, 0]
         # tpr = [0, 1, 1, 2] / ([0, 1, 1, 2] + [0, 0, 0, 0]) = [0, 1, 1, 1]
         # tnr = [1, 1, 2, 0] / ([1, 1, 2, 0] + [2, 1, 0, 1]) = [0.33333, 0.5, 1, 0]
         # balanced_accuracy = ([0, 1, 1, 1] + [0.33333, 0.5, 1, 0]) / 2 = [0.16666667, 0.75, 1, 0.5]
         # mean bal accuracy = 0.6041666666666666
 
-        balanced_acc_score = self.__get_custom_metric_score(preds, label, label.shape[1])
+        balanced_acc_score = self.__get_custom_metric_score(
+            preds, label, label.shape[1]
+        )
 
         self.assertAlmostEqual(balanced_acc_score, 0.6041666666, places=4)
 
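The 0.6041666 expectation in the comment block above can be reproduced directly from the example tensors. A small sketch of the per-label arithmetic, under the test's convention that a 0/0 rate counts as 0:

    import torch

    preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]]).float()
    label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]]).float()

    tp = (preds * label).sum(dim=0)              # [0, 1, 1, 2]
    fp = (preds * (1 - label)).sum(dim=0)        # [2, 1, 0, 1]
    tn = ((1 - preds) * (1 - label)).sum(dim=0)  # [1, 1, 2, 0]
    fn = ((1 - preds) * label).sum(dim=0)        # [0, 0, 0, 0]

    tpr = torch.nan_to_num(tp / (tp + fn))       # [0, 1, 1, 1]  (0/0 -> 0)
    tnr = torch.nan_to_num(tn / (tn + fp))       # [0.3333, 0.5, 1, 0]

    print(((tpr + tnr) / 2).mean().item())       # 0.6041666...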
@@ -78,7 +80,10 @@ def test_all_predictions_are_1_half_labels_are_1(self):
         preds = torch.ones((1, 900), dtype=torch.int)
         label = torch.ones((1, 900), dtype=torch.int)
 
-        mask = [[True] * (label.size(1) // 2) + [False] * (label.size(1) - (label.size(1) // 2))]
+        mask = [
+            [True] * (label.size(1) // 2)
+            + [False] * (label.size(1) - (label.size(1) // 2))
+        ]
         random.shuffle(mask[0])
         label[torch.tensor(mask)] = 0
 
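The mask construction above builds a single-row boolean mask with exactly `label.size(1) // 2` True entries, shuffles it in place, and zeroes the selected positions, so precisely half the labels flip to 0. A hypothetical equivalent using `torch.randperm`, shown only to illustrate the intent:

    import torch

    label = torch.ones((1, 900), dtype=torch.int)
    n = label.size(1)
    mask = torch.zeros(1, n, dtype=torch.bool)
    mask[0, torch.randperm(n)[: n // 2]] = True  # exactly n // 2 random positions
    label[mask] = 0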
@@ -88,7 +93,9 @@ def test_all_predictions_are_1_half_labels_are_1(self):
         # tnr = tn / (tn + fp) = [0, 0, 0, 0]
         # balanced accuracy = 1 / 4 = 0.25
 
-        balanced_acc_custom_score = self.__get_custom_metric_score(preds, label, label.shape[1])
+        balanced_acc_custom_score = self.__get_custom_metric_score(
+            preds, label, label.shape[1]
+        )
         self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
 
     def test_all_labels_are_1_half_predictions_are_1(self):
@@ -97,7 +104,10 @@ def test_all_labels_are_1_half_predictions_are_1(self):
         preds = torch.ones((1, 900), dtype=torch.int)
         label = torch.ones((1, 900), dtype=torch.int)
 
-        mask = [[True] * (label.size(1) // 2) + [False] * (label.size(1) - (label.size(1) // 2))]
+        mask = [
+            [True] * (label.size(1) // 2)
+            + [False] * (label.size(1) - (label.size(1) // 2))
+        ]
         random.shuffle(mask[0])
         preds[torch.tensor(mask)] = 0
 
@@ -107,7 +117,9 @@ def test_all_labels_are_1_half_predictions_are_1(self):
         # tnr = tn / (tn + fp) = [0, 0, 0, 0]
         # balanced accuracy = 1 / 4 = 0.25
 
-        balanced_acc_custom_score = self.__get_custom_metric_score(preds, label, label.shape[1])
+        balanced_acc_custom_score = self.__get_custom_metric_score(
+            preds, label, label.shape[1]
+        )
         self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
 
     @staticmethod
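Taken together, the tests pin down the metric's semantics: per-label true-positive and true-negative rates, averaged, with 0/0 treated as 0. A minimal sketch of an implementation with those semantics, assuming a torchmetrics `Metric` base class; this is not necessarily the repository's actual `BalancedAccuracy`:

    import torch
    from torchmetrics import Metric

    class BalancedAccuracySketch(Metric):
        """Per-label balanced accuracy: mean over labels of (tpr + tnr) / 2."""

        def __init__(self, num_labels: int):
            super().__init__()
            # running confusion-matrix counts, one entry per label
            self.add_state("tp", default=torch.zeros(num_labels), dist_reduce_fx="sum")
            self.add_state("fp", default=torch.zeros(num_labels), dist_reduce_fx="sum")
            self.add_state("tn", default=torch.zeros(num_labels), dist_reduce_fx="sum")
            self.add_state("fn", default=torch.zeros(num_labels), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
            preds, target = preds.float(), target.float()
            self.tp += (preds * target).sum(dim=0)
            self.fp += (preds * (1 - target)).sum(dim=0)
            self.tn += ((1 - preds) * (1 - target)).sum(dim=0)
            self.fn += ((1 - preds) * target).sum(dim=0)

        def compute(self) -> torch.Tensor:
            # labels with no positives (or no negatives) yield 0/0; treat as 0,
            # matching the hand-worked expectations in the tests above
            tpr = torch.nan_to_num(self.tp / (self.tp + self.fn))
            tnr = torch.nan_to_num(self.tn / (self.tn + self.fp))
            return ((tpr + tnr) / 2).mean()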