 from dataclasses import dataclass
 import enum
 import pathlib
-from typing import Dict, TextIO, cast
+from typing import TextIO, cast
 
 import pandas as pd
 
@@ -23,7 +23,7 @@ class ClassificationMetric(enum.Enum):
 class ClassificationScore(Score):
     per_category: pd.DataFrame
     macro_average: pd.Series
-    rocs: Dict[str, pd.DataFrame]
+    rocs: dict[str, pd.DataFrame]
     aggregate: pd.Series
 
     def __init__(
@@ -40,28 +40,29 @@ def __init__(
             self._category_score(
                 truth_probabilities[category],
                 prediction_probabilities[category],
-                truth_weights.score_weight,
+                truth_weights['score_weight'],
                 category,
             )
             for category in categories
         ]
     )
-    self.macro_average = self.per_category.mean(axis='index').rename(
-        'macro_average', inplace=True
-    )
+    # TODO: Fixed by https://github.com/pandas-dev/pandas-stubs/pull/1105
+    self.macro_average = self.per_category.mean(  # type: ignore[assignment]
+        axis='index'
+    ).rename('macro_average', inplace=True)
     self.rocs = {
         category: metrics.roc(
             truth_probabilities[category],
             prediction_probabilities[category],
-            truth_weights.score_weight,
+            truth_weights['score_weight'],
         )
         for category in categories
     }
     # Multi-category aggregate metrics
     self.aggregate = pd.Series(
         {
             'balanced_accuracy': metrics.balanced_multiclass_accuracy(
-                truth_probabilities, prediction_probabilities, truth_weights.score_weight
+                truth_probabilities, prediction_probabilities, truth_weights['score_weight']
             )
         },
         index=['balanced_accuracy'],
@@ -71,29 +72,29 @@ def __init__(
         if target_metric == ClassificationMetric.BALANCED_ACCURACY:
             self.overall = self.aggregate.at['balanced_accuracy']
             self.validation = metrics.balanced_multiclass_accuracy(
-                truth_probabilities, prediction_probabilities, truth_weights.validation_weight
+                truth_probabilities, prediction_probabilities, truth_weights['validation_weight']
             )
         elif target_metric == ClassificationMetric.AVERAGE_PRECISION:
-            self.overall = self.macro_average['ap']
+            self.overall = self.macro_average.at['ap']
             per_category_ap = pd.Series(
                 [
                     metrics.average_precision(
                         truth_probabilities[category],
                         prediction_probabilities[category],
-                        truth_weights.validation_weight,
+                        truth_weights['validation_weight'],
                     )
                     for category in categories
                 ]
             )
             self.validation = per_category_ap.mean()
         elif target_metric == ClassificationMetric.AUC:
-            self.overall = self.macro_average['auc']
+            self.overall = self.macro_average.at['auc']
             per_category_auc = pd.Series(
                 [
                     metrics.auc(
                         truth_probabilities[category],
                         prediction_probabilities[category],
-                        truth_weights.validation_weight,
+                        truth_weights['validation_weight'],
                     )
                     for category in categories
                 ]
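
Aside: a minimal, self-contained sketch (made-up data; not part of this commit) of the pandas and typing idioms the hunks above switch to:

import pandas as pd

# Hypothetical weights frame standing in for truth_weights.
weights = pd.DataFrame(
    {'score_weight': [1.0, 0.5], 'validation_weight': [1.0, 0.0]}
)

# Subscript access works for any column name; dynamic attribute access
# (weights.score_weight) cannot be typed precisely by pandas-stubs.
score_weight: pd.Series = weights['score_weight']

# Series.at is the label-based scalar accessor, so the lookup is known to
# yield a single value rather than a possible sub-Series.
column_means = weights.mean(axis='index')
print(column_means.at['score_weight'])  # 0.75

# PEP 585 (Python 3.9+): the builtin dict replaces typing.Dict in annotations.
rocs: dict[str, pd.DataFrame] = {}

Subscripting also survives column names that are not valid Python identifiers, which attribute access cannot express at all.
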
@@ -212,9 +213,10 @@ def from_file(
         prediction_file: pathlib.Path,
         target_metric: ClassificationMetric,
     ) -> ClassificationScore:
-        with truth_file.open('r') as truth_file_stream, prediction_file.open(
-            'r'
-        ) as prediction_file_stream:
+        with (
+            truth_file.open('r') as truth_file_stream,
+            prediction_file.open('r') as prediction_file_stream,
+        ):
             return cls.from_stream(
                 truth_file_stream,
                 prediction_file_stream,
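
Aside: the final hunk adopts Python 3.10's parenthesized context managers; a runnable sketch, with io.StringIO standing in for the real file streams:

import io

# Python 3.10+ allows parenthesizing multiple context managers, avoiding the
# awkward line-wrapped form the old code needed to stay within line length.
with (
    io.StringIO('truth') as truth_stream,            # hypothetical stand-in
    io.StringIO('prediction') as prediction_stream,  # hypothetical stand-in
):
    print(truth_stream.read(), prediction_stream.read())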