@@ -14,12 +14,13 @@ class CM:
     tn: float = 0
 
 class Metrics:
-    def __init__(self, df, experiment_name="unk"):
+    def __init__(self, df, experiment_name="unk", topk_metrics=False):
         # TODO fix this, it mask the fact that our model may return more values than it should for "model
         #self.df = df[~df["model_type_gold"].str.contains('not-present') | df["model_type_pred"].str.contains('model-best')]
         self.df = df[df["model_type_gold"].str.contains('model-best') | df["model_type_pred"].str.contains('model-best')]
         self.experiment_name = experiment_name
         self.metric_type = 'best'
+        self.topk_metrics = topk_metrics
 
     def matching(self, *col_names):
         return np.all([self.df[f"{name}_pred"] == self.df[f"{name}_gold"] for name in col_names], axis=0)
@@ -42,6 +43,11 @@ def binary_confusion_matrix(self, *col_names, best_only=True):
         gold_positive = relevant_gold
         equal = self.matching(*col_names)
 
+        if self.topk_metrics:
+            equal = pd.Series(equal, index=pred_positive.index).groupby('cell_ext_id').max()
+            pred_positive = pred_positive.groupby('cell_ext_id').head(1)
+            gold_positive = gold_positive.groupby('cell_ext_id').head(1)
+
         tp = (equal & pred_positive & gold_positive).sum()
         tn = (equal & ~pred_positive & ~gold_positive).sum()
         fp = (pred_positive & (~equal | ~gold_positive)).sum()
@@ -136,4 +142,4 @@ def show(self, df):
         pd.set_option('display.max_colwidth', old_width)
 
     def show_errors(self):
-        self.show(self.errors())
+        self.show(self.errors())
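Note on the added top-k handling in binary_confusion_matrix: every candidate row that shares a cell_ext_id is treated as one prediction group, the group counts as a match if any of its candidates equals the gold (boolean .max() over the group), and pred_positive / gold_positive are collapsed to one row per group so the tp/tn/fp sums stay one-per-cell. Below is a minimal sketch of that collapsing pattern, assuming a toy boolean Series whose index is named cell_ext_id; the data and variable names are illustrative only, not taken from this repository.

import pandas as pd

# Three ranked candidates for cell "a" (only the third matches the gold)
# and a single candidate for cell "b".
idx = pd.Index(["a", "a", "a", "b"], name="cell_ext_id")
equal = pd.Series([False, False, True, False], index=idx)

# One matching candidate anywhere in the group is enough for a hit.
equal_topk = equal.groupby("cell_ext_id").max()
# cell_ext_id
# a     True
# b    False

# pred_positive (and gold_positive) are reduced to one row per group so the
# confusion-matrix sums line up with the deduplicated `equal` series.
pred_positive = pd.Series([True, True, True, True], index=idx)
pred_topk = pred_positive.groupby("cell_ext_id").head(1)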