Skip to content

Commit f931cf6

Browse files
author
Marcin Kardas
committed
Compute top-k recall
1 parent 292a037 commit f931cf6

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

sota_extractor2/helpers/explainers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,14 @@ def _get_sota_records(self, paper):
168168
records.index.rename("cell_ext_id", inplace=True)
169169
return records
170170

171-
def linking_metrics(self, experiment_name="unk"):
171+
def linking_metrics(self, experiment_name="unk", topk_metrics=False, filtered=True):
172172
paper_ids = list(self.le.proposals.keys())
173173

174174
proposals = pd.concat(self.le.proposals.values())
175-
proposals = proposals[~proposals.index.isin(self.fe.reason.index)]
175+
176+
# if not topk_metrics:
177+
if filtered:
178+
proposals = proposals[~proposals.index.isin(self.fe.reason.index)]
176179

177180
papers = {paper_id: self.paper_collection.get_by_id(paper_id) for paper_id in paper_ids}
178181
missing = [paper_id for paper_id, paper in papers.items() if paper is None]
@@ -202,9 +205,10 @@ def linking_metrics(self, experiment_name="unk"):
202205
if "experiment_name" in df.columns:
203206
del df["experiment_name"]
204207

205-
metrics = Metrics(df, experiment_name=experiment_name)
208+
metrics = Metrics(df, experiment_name=experiment_name, topk_metrics=topk_metrics)
206209
return metrics
207210

211+
208212
def optimize_filters(self, metrics_info):
209213
results = optimize_filters(self, metrics_info)
210214
return results

sota_extractor2/models/linking/metrics.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ class CM:
1414
tn: float = 0
1515

1616
class Metrics:
17-
def __init__(self, df, experiment_name="unk"):
17+
def __init__(self, df, experiment_name="unk", topk_metrics=False):
1818
# TODO fix this, it masks the fact that our model may return more values than it should for "model"
1919
#self.df = df[~df["model_type_gold"].str.contains('not-present') | df["model_type_pred"].str.contains('model-best')]
2020
self.df = df[df["model_type_gold"].str.contains('model-best') | df["model_type_pred"].str.contains('model-best')]
2121
self.experiment_name = experiment_name
2222
self.metric_type = 'best'
23+
self.topk_metrics = topk_metrics
2324

2425
def matching(self, *col_names):
2526
return np.all([self.df[f"{name}_pred"] == self.df[f"{name}_gold"] for name in col_names], axis=0)
@@ -42,6 +43,11 @@ def binary_confusion_matrix(self, *col_names, best_only=True):
4243
gold_positive = relevant_gold
4344
equal = self.matching(*col_names)
4445

46+
if self.topk_metrics:
47+
equal = pd.Series(equal, index=pred_positive.index).groupby('cell_ext_id').max()
48+
pred_positive = pred_positive.groupby('cell_ext_id').head(1)
49+
gold_positive = gold_positive.groupby('cell_ext_id').head(1)
50+
4551
tp = (equal & pred_positive & gold_positive).sum()
4652
tn = (equal & ~pred_positive & ~gold_positive).sum()
4753
fp = (pred_positive & (~equal | ~gold_positive)).sum()
@@ -136,4 +142,4 @@ def show(self, df):
136142
pd.set_option('display.max_colwidth', old_width)
137143

138144
def show_errors(self):
139-
self.show(self.errors())
145+
self.show(self.errors())

0 commit comments

Comments (0)