Commit 2bfa1f8

Introduce explainers

* implement Table deepcopy so we can store Tables in various stages of the pipeline
* add explainers for:
  + table type
  + filtering
* add SessionRecorder to speed up testing
* modify ULMFiT interpretation classes to show signed word contributions

1 parent 00fb305

8 files changed (+283, -58 lines)

sota_extractor2/data/paper_collection.py

Lines changed: 6 additions & 0 deletions

@@ -28,6 +28,12 @@ def __init__(self, paper_id, text, tables, annotations):
         else:
             self.gold_tags = ''

+    def table_by_name(self, name):
+        for table in self.tables:
+            if table.name == name:
+                return table
+        return None
+

 # todo: make sure multithreading/processing won't cause collisions
 def random_id():
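For context, a minimal usage sketch of the new lookup helper; the paper object and table name below are hypothetical:

    # Hypothetical usage: look up one of a Paper's tables by name.
    table = paper.table_by_name("table_01.csv")
    if table is None:
        print("no table with that name in this paper")
    else:
        print(table.name, table.df.shape)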

sota_extractor2/data/table.py

Lines changed: 7 additions & 2 deletions

@@ -6,6 +6,8 @@
 from dataclasses import dataclass, field
 from typing import List
 from ..helpers.jupyter import display_table
+from copy import deepcopy
+

 @dataclass
 class Cell:
@@ -68,15 +70,18 @@ def read_str_csv(filename):
     return df


+class CellDataFrame(pd.DataFrame):
+    """We subclass pandas DataFrame in order to make deepcopy recursively copy cells"""
+    def __deepcopy__(self, memodict={}):
+        return CellDataFrame(self.applymap(lambda cell: deepcopy(cell, memodict)))


 class Table:
     def __init__(self, name, df, layout, caption=None, figure_id=None, annotations=None, migrate=False, old_name=None, guessed_tags=None):
         self.name = name
-        self.df = df
         self.caption = caption
         self.figure_id = figure_id
-        self.df = df.applymap(str2cell)
+        self.df = CellDataFrame(df.applymap(str2cell))

         if migrate:
             self.old_name = old_name
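Why the subclass matters: a plain deepcopy of a DataFrame holding object-dtype cells can leave the Cell objects shared between the copy and the original, so snapshots taken at different pipeline stages would all see later mutations. A minimal sketch of the behavior this enables; the Cell class here is an illustrative stand-in for the real dataclass:

    from copy import deepcopy
    import pandas as pd

    class Cell:
        def __init__(self, value):
            self.value = value

    class CellDataFrame(pd.DataFrame):
        """Make deepcopy recursively copy the Cell objects stored in the frame."""
        def __deepcopy__(self, memodict={}):
            return CellDataFrame(self.applymap(lambda cell: deepcopy(cell, memodict)))

    df = CellDataFrame(pd.DataFrame([[Cell("75.2")]]))
    snapshot = deepcopy(df)           # cells are copied, not shared
    df.iloc[0, 0].value = "edited"
    print(snapshot.iloc[0, 0].value)  # still "75.2"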

sota_extractor2/helpers/explainers.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+from ..models.structure import TableType
+from ..loggers import StructurePredictionEvaluator, LinkerEvaluator, FilteringEvaluator
+from IPython.display import display
+import pandas as pd
+
+
+class TableTypeExplainer:
+    def __init__(self, paper, table, table_type, probs):
+        self.paper = paper
+        self.table = table
+        self.table_type = table_type
+        self.probs = pd.DataFrame(probs, columns=["type", "probability"])
+
+    def __str__(self):
+        return f"Table {self.table.name} was labelled as {self.table_type}."
+
+    def display(self):
+        print(self)
+        display(self.probs)  # render the probabilities as a notebook table
+
+
+class Explainer:
+    def __init__(self, pipeline_logger, paper_collection):
+        self.spe = StructurePredictionEvaluator(pipeline_logger, paper_collection)
+        self.le = LinkerEvaluator(pipeline_logger, paper_collection)
+        self.fe = FilteringEvaluator(pipeline_logger)
+
+    def explain(self, paper, cell_ext_id):
+        paper_id, table_name, rc = cell_ext_id.split('/')
+        if paper.paper_id != paper_id:
+            return "No such cell"
+
+        row, col = [int(x) for x in rc.split('.')]
+
+        table_type, probs = self.spe.get_table_type_predictions(paper_id, table_name)
+
+        if table_type == TableType.IRRELEVANT:
+            return TableTypeExplainer(paper, paper.table_by_name(table_name), table_type, probs)
+
+        reason = self.fe.reason.get(cell_ext_id)
+        if reason is not None:
+            return reason
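A minimal usage sketch, assuming a pipeline_logger and paper collection are already wired up; the ids below are made up, and a cell is addressed as "<paper_id>/<table_name>/<row>.<col>" per the parsing above:

    explainer = Explainer(pipeline_logger, paper_collection)

    # Hypothetical ids; explain() returns a TableTypeExplainer, a filtering
    # reason, an error string, or None when no explanation is recorded.
    explanation = explainer.explain(paper, "1234.56789/table_01.csv/2.3")
    if explanation is not None:
        print(explanation)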

sota_extractor2/helpers/interpret.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+from fastai.text.interpret import TextClassificationInterpretation as AbsTextClassificationInterpretation, _eval_dropouts
+from fastai.basic_data import DatasetType
+import torch
+
+
+__all__ = ["TextClassificationInterpretation", "TextMultiClassificationInterpretation"]
+
+
+class TextClassificationInterpretation(AbsTextClassificationInterpretation):
+    @classmethod
+    def from_learner(cls, learner):
+        empty_preds = torch.Tensor([[1]])
+        return cls(learner, empty_preds, None, None)
+
+    def intrinsic_attention(self, text:str, class_id:int=None):
+        """Calculate the intrinsic attention of the input w.r.t. an output `class_id`, or the classification given by the model if `None`.
+        Similar to the base class, but does not apply abs() before summing the gradients.
+        """
+        self.model.train()
+        _eval_dropouts(self.model)
+        self.model.zero_grad()
+        self.model.reset()
+        ids = self.data.one_item(text)[0]
+        emb = self.model[0].module.encoder(ids).detach().requires_grad_(True)
+        lstm_output = self.model[0].module(emb, from_embeddings=True)
+        self.model.eval()
+        cl = self.model[1](lstm_output + (torch.zeros_like(ids).byte(),))[0].softmax(dim=-1)
+        if class_id is None: class_id = cl.argmax()
+        cl[0][class_id].backward()
+        # base class: attn = emb.grad.squeeze().abs().sum(dim=-1)
+        #             attn /= attn.max()
+        attn = emb.grad.squeeze().sum(dim=-1)
+        attn = attn / attn.abs().max() * 0.5 + 0.5
+        tokens = self.data.single_ds.reconstruct(ids[0])
+        return tokens, attn
+
+
+class TextMultiClassificationInterpretation(TextClassificationInterpretation):
+    def intrinsic_attention(self, text:str, class_id:int=None):
+        """Calculate the intrinsic attention of the input w.r.t. an output `class_id`, or the classification given by the model if `None`.
+        Similar to the base class, but uses sigmoid instead of softmax and does not apply abs() before summing the gradients.
+        """
+        self.model.train()
+        _eval_dropouts(self.model)
+        self.model.zero_grad()
+        self.model.reset()
+        ids = self.data.one_item(text)[0]
+        emb = self.model[0].module.encoder(ids).detach().requires_grad_(True)
+        lstm_output = self.model[0].module(emb, from_embeddings=True)
+        self.model.eval()
+        cl = self.model[1](lstm_output + (torch.zeros_like(ids).byte(),))[0].sigmoid()
+        if class_id is None: class_id = cl.argmax()
+        cl[0][class_id].backward()
+        # base class: attn = emb.grad.squeeze().abs().sum(dim=-1)
+        #             attn /= attn.max()
+        attn = emb.grad.squeeze().sum(dim=-1)
+        attn = attn / attn.abs().max() * 0.5 + 0.5
+        tokens = self.data.single_ds.reconstruct(ids[0])
+        return tokens, attn
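The rescaling attn / attn.abs().max() * 0.5 + 0.5 maps the signed per-token gradient sums into [0, 1] with 0.5 as the neutral point, so tokens that support the class (> 0.5) stay distinguishable from tokens that oppose it (< 0.5). A small self-contained check of just that normalization step, on made-up values:

    import torch

    # Signed per-token contributions, as produced by summing embedding gradients.
    attn = torch.tensor([0.8, -0.4, 0.0])
    attn = attn / attn.abs().max() * 0.5 + 0.5
    print(attn)  # tensor([1.0000, 0.2500, 0.5000]): >0.5 supports, <0.5 opposes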

sota_extractor2/loggers.py

Lines changed: 84 additions & 4 deletions

@@ -2,6 +2,8 @@
 import pandas as pd
 from .models.structure.experiment import Experiment, label_map, Labels
 from .models.structure.type_predictor import TableType
+from copy import deepcopy
+import pickle


 class BaseLogger:
@@ -21,30 +23,85 @@ def __call__(self, step, **kwargs):
         print(f"[STEP] {step}: {kwargs}", file=self.file)


+class SessionRecorder:
+    def __init__(self, pipeline_logger):
+        self.pipeline_logger = pipeline_logger
+        self.session = []
+        self._recording = False
+
+    def __call__(self, step, **kwargs):
+        self.session.append((step, deepcopy(kwargs)))
+
+    def reset(self):
+        self.session = []
+
+    def record(self):
+        if not self._recording:
+            self.pipeline_logger.register(".*", self)
+            self._recording = True
+
+    def stop(self):
+        if self._recording:
+            self.pipeline_logger.unregister(".*", self)
+            self._recording = False
+
+    def replay(self):
+        self.stop()
+        for step, kwargs in self.session:
+            self.pipeline_logger(step, **kwargs)
+
+    def save_session(self, path):
+        with open(path, "wb") as f:
+            pickle.dump(self.session, f)
+
+    def load_session(self, path):
+        with open(path, "rb") as f:
+            self.session = pickle.load(f)
+
+
 class StructurePredictionEvaluator:
     def __init__(self, pipeline_logger, pc):
-        pipeline_logger.register("structure_prediction::tables_labelled", self.on_tables_labelled)
+        pipeline_logger.register("structure_prediction::evidences_split", self.on_evidences_split)
+        pipeline_logger.register("structure_prediction::tables_labeled", self.on_tables_labeled)
         pipeline_logger.register("type_prediction::predicted", self.on_type_predicted)
+        pipeline_logger.register("type_prediction::multiclass_predicted", self.on_type_multiclass_predicted)
         self.pc = pc
         self.results = {}
         self.type_predictions = {}
+        self.type_multiclass_predictions = {}
+        self.evidences = pd.DataFrame()
+
+    def on_type_multiclass_predicted(self, step, paper, tables, threshold, predictions):
+        for table, prediction in zip(tables, predictions):
+            self.type_multiclass_predictions[paper.paper_id, table.name] = {
+                TableType.SOTA: prediction[0],
+                TableType.ABLATION: prediction[1],
+                TableType.IRRELEVANT: threshold
+            }

     def on_type_predicted(self, step, paper, tables, predictions):
-        self.type_predictions[paper.paper_id] = predictions
+        for table, prediction in zip(tables, predictions):
+            self.type_predictions[paper.paper_id, table.name] = prediction
+
+    def on_evidences_split(self, step, evidences, evidences_num):
+        self.evidences = pd.concat([self.evidences, evidences])

-    def on_tables_labelled(self, step, paper, tables):
+    def on_tables_labeled(self, step, paper, labeled_tables):
         golds = [p for p in self.pc if p.text.title == paper.text.title]
         paper_id = paper.paper_id
         type_results = []
         cells_results = []
+        labeled_tables = {table.name: table for table in labeled_tables}
         if len(golds) == 1:
             gold = golds[0]
-            for gold_table, table, table_type in zip(gold.tables, paper.tables, self.type_predictions.get(paper.paper_id, [])):
+            for gold_table, table in zip(gold.tables, paper.tables):
+                table_type = self.type_predictions[paper.paper_id, table.name]
                 is_important = table_type == TableType.SOTA or table_type == TableType.ABLATION
                 gold_is_important = "sota" in gold_table.gold_tags or "ablation" in gold_table.gold_tags
                 type_results.append({"predicted": is_important, "gold": gold_is_important, "name": table.name})
                 if not is_important:
                     continue
+                table = labeled_tables[table.name]
                 rows, cols = table.df.shape
                 for r in range(rows):
                     for c in range(cols):
@@ -76,6 +133,14 @@ def metrics(self, paper_id):
         e._set_results(paper_id, self.map_tags(results['cells'].predicted), self.map_tags(results['cells'].gold))
         e.show_results(paper_id, normalize=True)

+    def get_table_type_predictions(self, paper_id, table_name):
+        prediction = self.type_predictions.get((paper_id, table_name))
+        multi_predictions = self.type_multiclass_predictions.get((paper_id, table_name))
+        if prediction is not None:
+            multi_predictions = sorted(multi_predictions.items(), key=lambda x: x[1], reverse=True)
+            return prediction, [(k.name, v) for k, v in multi_predictions]
+

 class LinkerEvaluator:
     def __init__(self, pipeline_logger, pc):
@@ -102,3 +167,18 @@ def on_taxonomy_topk(self, step, ext_id, topk):

     def top_matches(self, paper_id, table_name, row, col):
         return self.topk[(paper_id, table_name, row, col)]
+
+
+class FilteringEvaluator:
+    def __init__(self, pipeline_logger):
+        pipeline_logger.register("filtering::.*::filtered", self.on_filtered)
+        self.proposals = {}
+        self.which = {}
+        self.reason = pd.Series(dtype=str)
+
+    def on_filtered(self, step, proposals, which, reason, **kwargs):
+        _, filter_step, _ = step.split('::')
+        if filter_step != "compound_filtering":
+            self.proposals[filter_step] = pd.concat(self.proposals.get(filter_step, []) + [proposals])
+            self.which[filter_step] = pd.concat(self.which.get(filter_step, []) + [which])
+            self.reason = self.reason.append(reason)
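For context, a hypothetical SessionRecorder workflow, assuming a pipeline_logger instance with the register/unregister API used above: record a pipeline run once, pickle it, then replay the captured steps later so evaluators can be exercised without re-running the pipeline.

    recorder = SessionRecorder(pipeline_logger)

    recorder.record()               # subscribe to all steps (".*")
    run_pipeline(paper_collection)  # hypothetical: any code emitting pipeline_logger calls
    recorder.stop()
    recorder.save_session("session.pkl")

    # Later, e.g. in tests: replay the captured steps instead of running the pipeline.
    recorder.load_session("session.pkl")
    explainer = Explainer(pipeline_logger, paper_collection)  # evaluators subscribe first
    recorder.replay()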
