import numpy as np
import pandas as pd
from dataclasses import replace
from matplotlib import pyplot as plt

from sota_extractor2.models.linking.metrics import CM

# NOTE: this module also relies on helpers that are expected to be in scope when
# it is used and are not defined or imported here: `cs` (a context-search object
# exposing find_datasets/find_metrics and the datasets/metrics short-form maps),
# `de` (a dataset-mention extractor with a from_paper method),
# `context_search.extract_acronyms`, and the normalize_cell / normalize_cell_ws /
# normalize_dataset functions.


def annotations(matrix, structure, r, c, type='model'):
    # Gather the annotations of the given type that describe cell (r, c):
    # cells to its left in the same row and cells above it in the same column
    # whose gold structure tag contains `type`.
    ann = []
    for nc in range(0, c):
        if type in structure[r, nc]:
            ann.append(matrix[r, nc])
    for nr in range(0, r):
        if type in structure[nr, c]:
            ann.append(matrix[nr, c])
    return ' '.join(ann)


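# Illustrative only: a toy matrix and gold-tag structure showing which header
# cells annotate cell (1, 1); the values below are made up for the example.
def _annotations_example():
    matrix = np.array([["", "SQuAD", "CoNLL"],
                       ["BERT", "80.1", "92.4"],
                       ["LSTM", "75.3", "88.0"]], dtype=object)
    structure = np.array([["", "dataset", "dataset"],
                          ["model", "", ""],
                          ["model", "", ""]], dtype=object)
    assert annotations(matrix, structure, 1, 1, type='model') == "BERT"
    assert annotations(matrix, structure, 1, 1, type='dataset') == "SQuAD"

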
def estimate_noises(extracted_values, gold_values, short_forms):
    # For each gold value, estimate the noise as the fraction of extracted
    # mentions that are not among its known short forms.
    if not len(extracted_values):
        return {}
    extracted_values = set(extracted_values)
    gold_values = set(gold_values)

    return {gold: 1 - len(extracted_values & set(short_forms.get(gold, set()))) / len(extracted_values)
            for gold in gold_values}


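# Illustrative only: the noise of a gold label is the fraction of extracted
# mentions that are not among its known short forms (toy short-form map below).
def _estimate_noises_example():
    extracted = ["squad", "nli"]
    gold = ["SQuAD 1.1"]
    short_forms = {"SQuAD 1.1": {"squad", "squad 1.1"}}
    # one of the two extracted mentions is a known short form of "SQuAD 1.1"
    assert estimate_noises(extracted, gold, short_forms) == {"SQuAD 1.1": 0.5}

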
def estimate_context_noise(context, records):
    # Estimate, for a single context (paper text, caption or cell annotations),
    # how noisy the extracted dataset and metric mentions are w.r.t. the gold records.
    context = context or ""
    abbrvs = context_search.extract_acronyms(context)
    context = normalize_cell_ws(normalize_dataset(context))
    dss = set(cs.find_datasets(context)) | set(abbrvs.keys())
    mss = set(cs.find_metrics(context))
    dss -= mss
    dss = {normalize_cell(ds) for ds in dss}
    mss = {normalize_cell(ms) for ms in mss}

    gold_ds = set(records.dataset.values)
    gold_ms = set(records.metric.values)
    ds_noises = estimate_noises(dss, gold_ds, cs.datasets)
    ms_noises = estimate_noises(mss, gold_ms, cs.metrics)

    return ds_noises, ms_noises


def estimate_paper_context_noise(paper, gold_sota_records):
    records = gold_sota_records[gold_sota_records.paper_id == paper.paper_id]
    datasets = de.from_paper(paper)
    context = " ".join(datasets)
    return estimate_context_noise(context, records)


def estimate_caption_context_noise(paper, table, gold_sota_records):
    table_ext_id = f"{paper.paper_id}/{table.name}/"
    records = gold_sota_records[gold_sota_records.index.str.startswith(table_ext_id)]
    return estimate_context_noise(table.caption, records)


def estimate_cell_context_noise(paper, table, row, col, gold_sota_records):
    cell_ext_id = f"{paper.paper_id}/{table.name}/{row}.{col}"
    records = gold_sota_records[gold_sota_records.index == cell_ext_id]
    # use the cell's dataset annotations (row and column headers tagged as
    # 'dataset') as its textual context
    value = annotations(table.matrix.values, table.matrix_gold_tags.values, row, col, 'dataset')
    return estimate_context_noise(value, records)


def average_dicts(dicts):
    sums = {}
    for d in dicts:
        for k, v in d.items():
            sums.setdefault(k, []).append(v)
    return {k: np.mean(v) for k, v in sums.items()}


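# Illustrative only: keys missing from some dicts are averaged only over the
# dicts that contain them.
def _average_dicts_example():
    avg = average_dicts([{"SQuAD": 0.25}, {"SQuAD": 0.75, "CoNLL": 1.0}])
    assert avg == {"SQuAD": 0.5, "CoNLL": 1.0}

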
def all_equal(row):
    # True iff the merged row agrees between prediction and gold on every
    # record field.
    cols = ["model_type", "dataset", "metric", "task", "parsed"]
    return np.all([row[f"{name}_pred"] == row[f"{name}_gold"] for name in cols])


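# Illustrative only: a hypothetical merged row (in the shape produced by
# merge_gold_records below) on which prediction and gold agree on every field.
def _all_equal_example():
    row = pd.Series({
        "model_type_pred": "model-best", "model_type_gold": "model-best",
        "dataset_pred": "SQuAD", "dataset_gold": "SQuAD",
        "metric_pred": "F1", "metric_gold": "F1",
        "task_pred": "QA", "task_gold": "QA",
        "parsed_pred": 93.2, "parsed_gold": 93.2,
    })
    assert all_equal(row)

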
def merge_gold_records(explainer):
    # Outer-join the gold sota records with the linker's proposals (matched on
    # cell_ext_id) and mark, per record, whether the prediction agrees with gold.
    paper_ids = list(explainer.le.proposals.keys())

    proposals = pd.concat(explainer.le.proposals.values())

    papers = {paper_id: explainer.paper_collection.get_by_id(paper_id) for paper_id in paper_ids}
    missing = [paper_id for paper_id, paper in papers.items() if paper is None]
    if missing:
        print("Missing papers in paper collection:")
        print(", ".join(missing))
    papers = [paper for paper in papers.values() if paper is not None]

    if explainer.gold_sota_records is None:
        print("gold_sota_records is missing")
        return
    gold_sota_records = explainer.gold_sota_records
    # keep only gold records belonging to papers for which we have proposals
    which = gold_sota_records.index.to_series().str.split("/", expand=True)[0] \
        .isin([paper.paper_id for paper in papers])
    gold_sota_records = gold_sota_records[which]

    df = gold_sota_records.merge(proposals, 'outer', left_index=True, right_index=True, suffixes=['_gold', '_pred'])
    df = df.reindex(sorted(df.columns), axis=1)
    df["confidence"] = df.confidence.fillna(0.0)
    df = df.fillna('not-present')
    df["equal"] = df.apply(all_equal, axis=1)
    df["pred_positive"] = df["model_type_pred"].str.contains("model-best")
    df["gold_positive"] = df["model_type_gold"].str.contains("model-best")
    return df


def find_threshold_intervals(proposals, metrics_info, context="paper"):
    # For each proposal, compute the confidence-threshold interval
    # (min_threshold, max_threshold) in which it would be returned.

    # maximal threshold to have this proposal returned
    proposals["max_threshold"] = proposals.confidence

    proposals["min_threshold"] = 0.0

    ignore = (proposals.model_type_pred != 'model-best') | (proposals.struct_model_type == '') | \
             (proposals.struct_dataset.str.contains('dev')) | (proposals.struct_dataset.str.contains('train'))

    # these proposals are never returned due to structure or model-type filters
    proposals.loc[ignore, "min_threshold"] = 1.0
    proposals.loc[ignore, "max_threshold"] = 0.0

    all_proposals = proposals
    proposals = proposals[~ignore]

    # group competing proposals either per paper or per table
    if context == "paper":
        context_column = proposals.index.to_series().str.split('/', expand=False).apply(lambda x: x[0])
    else:
        context_column = proposals.index.to_series().str.split('/', expand=False).apply(lambda x: x[0] + "/" + x[1])

    for i, p in proposals.iterrows():
        key = (p.task_pred, p.dataset_pred, p.metric_pred)
        # competing proposals: same context and the same (task, dataset, metric) triple
        proposals_context = proposals[context_column == context_column[p.name]]
        proposals_context = proposals_context[~proposals_context.parsed_pred.isna()]
        proposals_context = proposals_context[
            (proposals_context.task_pred == p.task_pred) &
            (proposals_context.dataset_pred == p.dataset_pred) &
            (proposals_context.metric_pred == p.metric_pred)
        ]

        # determine whether higher or lower metric values are better
        d = 0
        if key in metrics_info:
            d = metrics_info[key]
        elif p.metric_pred in metrics_info:
            d = metrics_info[p.metric_pred]
        elif 'error' in p.metric_pred.lower():
            d = -1
        elif 'accuracy' in p.metric_pred.lower():
            d = 1
        d = 1 if d >= 0 else -1

        # the minimal threshold above which all superior results are ignored
        which = d * proposals_context.parsed_pred > d * p.parsed_pred
        if np.any(which.values):
            all_proposals.at[i, "min_threshold"] = proposals_context[which].confidence.values.max()
        else:
            # if another proposal with the same value comes first, this one is
            # only returned above that proposal's confidence
            ties = proposals_context[proposals_context.parsed_pred == p.parsed_pred]
            if len(ties) and ties.iloc[0].name != p.name:
                all_proposals.at[i, "min_threshold"] = ties.iloc[0].confidence

    return all_proposals


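# Illustrative only: a toy proposals frame with the columns used above (all
# values are made up). The weaker of two competing results gets min_threshold
# equal to the confidence (0.6) of the superior one, while its own confidence
# (0.9) becomes its max_threshold.
def _find_threshold_intervals_example():
    proposals = pd.DataFrame({
        "confidence":        [0.9, 0.6],
        "model_type_pred":   ["model-best", "model-best"],
        "struct_model_type": ["model-best", "model-best"],
        "struct_dataset":    ["test", "test"],
        "task_pred":         ["QA", "QA"],
        "dataset_pred":      ["SQuAD", "SQuAD"],
        "metric_pred":       ["F1", "F1"],
        "parsed_pred":       [0.91, 0.93],
    }, index=["paper1/table_01.csv/1.2", "paper1/table_01.csv/2.2"])
    metrics_info = {"F1": 1}  # +1: higher is better, -1: lower is better
    out = find_threshold_intervals(proposals.copy(), metrics_info, context="paper")
    print(out[["min_threshold", "max_threshold"]])

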
def update_cm(proposal, cm, is_activated):
    # Incrementally update the confusion matrix when a proposal is activated
    # (returned after a threshold change) or deactivated (filtered out).
    d = 1 if is_activated else -1
    if proposal.equal and proposal.pred_positive and proposal.gold_positive:
        cm = replace(cm, tp=cm.tp + d, fn=cm.fn - d)
    if proposal.equal and not proposal.pred_positive and not proposal.gold_positive:
        cm = replace(cm, tn=cm.tn + d)
    if proposal.pred_positive and (not proposal.equal or not proposal.gold_positive):
        cm = replace(cm, fp=cm.fp + d)
    # if proposal.gold_positive and (not proposal.equal or not proposal.pred_positive):
    #     cm = replace(cm, fn=cm.fn + d)
    return cm


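# Illustrative only: assuming CM's fields default to 0 (as suggested by
# CM(fn=...) in sweep_thresholds below), activating a correct positive proposal
# moves one record from false negatives to true positives.
def _update_cm_example():
    cm = CM(fn=1)
    proposal = pd.Series(dict(equal=True, pred_positive=True, gold_positive=True))
    cm = update_cm(proposal, cm, is_activated=True)
    assert (cm.tp, cm.fn) == (1, 0)

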
def sweep_thresholds(df):
    # Sweep over pairs of confidence thresholds (threshold1, threshold2),
    # updating the confusion matrix incrementally as proposals become activated
    # (their min_threshold is crossed) or deactivated (their max_threshold is
    # crossed), and record precision/recall/F1 for every pair.
    cm = CM(fn=sum(df.gold_positive))
    # drop proposals that can never be returned
    df = df[df.min_threshold < df.max_threshold]

    # one activation and one deactivation event per proposal
    sweeps = df.reset_index().melt(id_vars="cell_ext_id", value_vars=["min_threshold", "max_threshold"],
                                   var_name="threshold_type", value_name="threshold")

    sweeps = sweeps.sort_values(by=["threshold", "threshold_type"]).reset_index(drop=True)

    steps = sweeps.threshold.drop_duplicates().index

    results = []
    for i, idx1 in enumerate(steps[:-1]):
        th1 = sweeps.threshold[idx1]

        to_restore = cm
        for j, idx2 in enumerate(steps[i + 1:], i + 1):
            th2 = sweeps.threshold[idx2]
            # 1e-8 avoids division by zero for empty counts
            precision = cm.tp / (cm.tp + cm.fp + 1e-8)
            recall = cm.tp / (cm.tp + cm.fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

            result = dict(threshold1=th1, threshold2=sweeps.threshold[idx2 - 1], tp=cm.tp, tn=cm.tn, fp=cm.fp, fn=cm.fn,
                          precision=precision, recall=recall, f1=f1)
            results.append(result)
            for _, row in sweeps[sweeps.threshold == sweeps.threshold[idx2 - 1]].iterrows():
                proposal = df.loc[row.cell_ext_id]
                is_activated = row.threshold_type == 'min_threshold'
                if not is_activated and proposal.min_threshold < th1:
                    cm = update_cm(proposal, cm, is_activated)

            precision = cm.tp / (cm.tp + cm.fp + 1e-8)
            recall = cm.tp / (cm.tp + cm.fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

            result = dict(threshold1=th1, threshold2=th2, tp=cm.tp, tn=cm.tn, fp=cm.fp, fn=cm.fn,
                          precision=precision, recall=recall, f1=f1)
            results.append(result)

        # restore the state at threshold1 before advancing the outer sweep
        cm = to_restore

        for _, row in sweeps[sweeps.threshold == th1].iterrows():
            proposal = df.loc[row.cell_ext_id]

            is_activated = row.threshold_type == 'min_threshold'
            cm = update_cm(proposal, cm, is_activated)

    return df, sweeps, steps, pd.DataFrame(results)


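# Illustrative only: a toy frame with two proposals in the shape produced by
# merge_gold_records + find_threshold_intervals; `results` holds one row of
# confusion-matrix counts and precision/recall/F1 per (threshold1, threshold2) pair.
def _sweep_thresholds_example():
    df = pd.DataFrame({
        "min_threshold": [0.0, 0.6],
        "max_threshold": [0.6, 0.9],
        "equal":         [True, False],
        "pred_positive": [True, True],
        "gold_positive": [True, True],
    }, index=pd.Index(["paper1/table_01.csv/1.2", "paper1/table_01.csv/2.2"],
                      name="cell_ext_id"))
    df, sweeps, steps, results = sweep_thresholds(df)
    print(results[["threshold1", "threshold2", "precision", "recall", "f1"]])

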
class PRResults:
    def __init__(self, results):
        self.results = results

    def plot(self):
        plt.figure(figsize=(6, 6))
        plt.plot(self.results["precision"], self.results["recall"], '.')
        plt.xlabel("precision")
        plt.ylabel("recall")

    def _best(self, results, metric):
        b = results.loc[results[metric].idxmax()]
        x = ["precision", "recall", "f1"]
        x.remove(metric)
        y = [b[m] for m in x]
        print(f"Best {metric}={b[metric]:0.2f} (with {x[0]}={y[0]:.2f} and {x[1]}={y[1]:.2f})"
              f" is achieved with threshold1={b.threshold1} and threshold2={b.threshold2}")

    def best(self, min_precision=0, min_recall=0, min_f1=0):
        results = self.results
        results = results[
            (results.precision >= min_precision) &
            (results.recall >= min_recall) &
            (results.f1 >= min_f1)
        ]
        if not len(results):
            print("No results match these criteria")
        else:
            self._best(results, "precision")
            self._best(results, "recall")
            self._best(results, "f1")


def optimize_filters(explainer, metrics_info):
    # End-to-end sweep: merge gold records with proposals, compute each
    # proposal's threshold interval, and sweep threshold pairs to obtain
    # precision/recall trade-offs.
    df = merge_gold_records(explainer)
    df = find_threshold_intervals(df, metrics_info, context="paper")
    df, sweeps, steps, results = sweep_thresholds(df)
    return PRResults(results)
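

# Illustrative only: a hypothetical driver showing how the pieces above fit
# together. `explainer` is assumed to expose le.proposals, paper_collection and
# gold_sota_records as used in merge_gold_records; the metric names below are
# made up. metrics_info maps a metric name (or a (task, dataset, metric) triple)
# to its direction: positive when higher is better, negative when lower is better.
def _optimize_filters_example(explainer):
    metrics_info = {"Accuracy": 1, "Error rate": -1, ("QA", "SQuAD", "F1"): 1}
    pr = optimize_filters(explainer, metrics_info)
    pr.plot()                    # precision/recall scatter over all threshold pairs
    pr.best(min_precision=0.5)   # best operating points under a precision constraint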