# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import re
import pandas as pd

from axcell.data.paper_collection import remove_arxiv_version

def norm_score_str(x):
    """Normalize a metric score to a canonical string form.

    Repairs float-representation artifacts (e.g. '89.999999' -> '90',
    '0.8999999' -> '0.9'), strips trailing zeros after the decimal point,
    and maps '-0' to '0', so that textually different but numerically
    equal scores compare as equal.
    """
    x = str(x)
    # A run of five or more trailing 9s makes up the whole fraction:
    # round the integer part up and drop the fraction ('12.999999' -> '13').
    if re.match(r'^(\+|-|)(\d+)\.9{5,}$', x):
        x = re.sub(r'^(\+|-|)(\d+)\.9{5,}$', lambda a: a.group(1) + str(int(a.group(2)) + 1), x)
    # Otherwise round up the last non-9 digit before a trailing run of 9s
    # ('0.8999999' -> '0.9').
    elif x.endswith('9' * 5) and '.' in x:
        x = re.sub(r'([0-8])9+$', lambda a: str(int(a.group(1)) + 1), x)
    if '.' in x:
        # Strip trailing zeros, then a dangling decimal point ('76.50' -> '76.5').
        x = re.sub(r'0+$', '', x)
        if x[-1] == '.':
            x = x[:-1]
    if x == '-0':
        x = '0'
    return x
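
# Illustrative examples (added for exposition; the inputs are hypothetical,
# the outputs follow from the rules above):
#
#   norm_score_str('89.999999') -> '90'    (all-9 fraction: round the integer up)
#   norm_score_str('0.8999999') -> '0.9'   (trailing 9s: bump the last non-9 digit)
#   norm_score_str('76.50')     -> '76.5'  (trailing zeros stripped)
#   norm_score_str('-0.0')      -> '0'     (negative zero normalized)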

# Tiny constant keeping precision/recall/F1 well-defined (avoids division
# by zero when a paper has no predicted or no ground-truth records).
epsilon = 1e-10

def precision(tp, fp):
    """Precision = TP / (TP + FP), smoothed by epsilon."""
    pred_positives = tp + fp + epsilon
    return (1.0 * tp) / pred_positives


def recall(tp, fn):
    """Recall = TP / (TP + FN), smoothed by epsilon."""
    true_positives = tp + fn + epsilon
    return (1.0 * tp) / true_positives


def f1(prec, recall):
    """Harmonic mean of precision and recall, smoothed by epsilon."""
    norm = prec + recall + epsilon
    return 2 * prec * recall / norm

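# Worked example (hypothetical counts, added for exposition): with TP=3, FP=1,
# FN=2 we get precision = 3/4 = 0.75, recall = 3/5 = 0.6, and
# f1 = 2 * 0.75 * 0.6 / (0.75 + 0.6) = 0.9 / 1.35 ~ 0.667 (up to epsilon).
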
def stats(predictions, ground_truth, axis=None):
    """Compute micro- and macro-averaged precision/recall/F1 of extracted records.

    Each record is a (paper, task, dataset, metric, value) tuple; `axis`
    selects which fields a prediction must match to count as correct:
    'tdm' ignores the value, 'tdms' (or None) requires an exact match,
    and any other field name compares just that field per paper.
    """
    gold = pd.DataFrame(ground_truth, columns=["paper", "task", "dataset", "metric", "value"])
    pred = pd.DataFrame(predictions, columns=["paper", "task", "dataset", "metric", "value"])

    if axis == 'tdm':
        columns = ['paper', 'task', 'dataset', 'metric']
    elif axis == 'tdms' or axis is None:
        columns = ['paper', 'task', 'dataset', 'metric', 'value']
    else:
        columns = ['paper', axis]
    gold = gold[columns].drop_duplicates()
    pred = pred[columns].drop_duplicates()

    # An outer merge with an indicator column classifies every record:
    # present in both frames (TP), only in gold (FN), or only in pred (FP).
    results = gold.merge(pred, on=columns, how="outer", indicator=True)

    is_correct = results["_merge"] == "both"
    no_pred = results["_merge"] == "left_only"
    no_gold = results["_merge"] == "right_only"

    results["TP"] = is_correct.astype('int8')
    results["FP"] = no_gold.astype('int8')
    results["FN"] = no_pred.astype('int8')

    # Per-paper counts yield the macro-averaged metrics; summing them over
    # all papers yields the micro-averaged ones.
    m = results.groupby(["paper"]).agg({"TP": "sum", "FP": "sum", "FN": "sum"})
    m["precision"] = precision(m.TP, m.FP)
    m["recall"] = recall(m.TP, m.FN)
    m["f1"] = f1(m.precision, m.recall)

    TP_ALL = m.TP.sum()
    FP_ALL = m.FP.sum()
    FN_ALL = m.FN.sum()

    prec, reca = precision(TP_ALL, FP_ALL), recall(TP_ALL, FN_ALL)
    return {
        'Micro Precision': prec,
        'Micro Recall': reca,
        'Micro F1': f1(prec, reca),
        'Macro Precision': m.precision.mean(),
        'Macro Recall': m.recall.mean(),
        'Macro F1': m.f1.mean()
    }

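# Sketch of a direct call to stats (hypothetical records, added for exposition):
#
#   gold = [("1234.5678", "QA", "SQuAD", "F1", "89.5")]
#   pred = [("1234.5678", "QA", "SQuAD", "F1", "89.5"),
#           ("1234.5678", "QA", "SQuAD", "EM", "81")]
#   stats(pred, gold)  # TP=1, FP=1, FN=0 -> micro precision 0.5, recall 1.0
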
def evaluate(predictions, ground_truth):
    """Evaluate extracted (task, dataset, metric, score) records against gold ones.

    Both arguments are DataFrames with an 'arxiv_id' column and either a
    'score' or a 'value' column; scores are normalized with norm_score_str
    and arXiv version suffixes are dropped before records are compared.
    """
    predictions = predictions.copy()
    ground_truth = ground_truth.copy()
    predictions['value'] = predictions['score' if 'score' in predictions else 'value'].apply(norm_score_str)
    ground_truth['value'] = ground_truth['score' if 'score' in ground_truth else 'value'].apply(norm_score_str)
    predictions['paper'] = predictions['arxiv_id'].apply(remove_arxiv_version)
    ground_truth['paper'] = ground_truth['arxiv_id'].apply(remove_arxiv_version)

    metrics = []
    # Evaluate at several granularities: exact TDMS match, TDM ignoring the
    # score, and each individual field on its own.
    for axis in ['tdms', 'tdm', 'task', 'dataset', 'metric']:
        s = stats(predictions, ground_truth, axis)
        s['type'] = {'tdms': 'TDMS', 'tdm': 'TDM', 'task': 'Task', 'dataset': 'Dataset', 'metric': 'Metric'}.get(axis)
        metrics.append(s)
    columns = ['Micro Precision', 'Micro Recall', 'Micro F1', 'Macro Precision', 'Macro Recall', 'Macro F1']
    # Index the rows by evaluation type so they stay identifiable in the output.
    return pd.DataFrame(metrics).set_index('type')[columns]
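
# A minimal smoke test (hypothetical data, added for exposition; the arXiv id
# and scores below are made up, the column names follow evaluate's contract):
if __name__ == "__main__":
    gold = pd.DataFrame({
        'arxiv_id': ['1234.56789v1'],
        'task': ['Question Answering'],
        'dataset': ['SQuAD'],
        'metric': ['F1'],
        'score': ['89.999999'],  # normalized to '90' before comparison
    })
    pred = gold.assign(score=['90'])  # numerically equal prediction
    print(evaluate(pred, gold))       # all metrics should be ~1.0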