
Commit eb2f1bc

Code for initial NBSVM baseline
1 parent 254f473 commit eb2f1bc

4 files changed: +266 −0 lines

sota_extractor2/config.py

Lines changed: 8 additions & 0 deletions
@@ -14,3 +14,11 @@
elastic = dict(hosts=['localhost'], timeout=20)

arxiv = data/'arxiv'
htmls_raw = arxiv/'htmls'
htmls_clean = arxiv/'htmls-clean'

datasets = data/"datasets"
datasets_structure = datasets/"structure"
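
The new entries assume that data is a Path object defined earlier in config.py (not shown in this hunk), so subdirectories are composed with the / operator. A minimal sketch of how the new paths might be consumed elsewhere; the mkdir call and the evidence file name are assumptions, not part of this commit:

from sota_extractor2 import config

# ensure the structure dataset directory exists, e.g. data/datasets/structure
config.datasets_structure.mkdir(parents=True, exist_ok=True)
evidence_csv = config.datasets_structure / "evidence.csv"  # hypothetical file name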

sota_extractor2/data/structure.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import re
import pandas as pd
from collections import namedtuple
import hashlib
from fastai.text import progress_bar
from .elastic import Fragment
from .json import *


def get_all_tables(papers):
    # yield only tables that have gold tags and are not marked as trash
    for paper in papers:
        for table in paper.table_set.all():
            if 'trash' not in table.gold_tags and table.gold_tags != '':
                table.paper_id = paper.arxiv_id
                yield table


def consume_cells(*matrix):
    # iterate over several aligned matrices at once,
    # yielding one (row, col, vals) record per cell
    Cell = namedtuple('AnnCell', 'row col vals')
    for row_id, row in enumerate(zip(*matrix)):
        for col_id, cell_val in enumerate(zip(*row)):
            yield Cell(row=row_id, col=col_id, vals=cell_val)


def fetch_evidence(cell_content, paper_id, paper_limit=10, corpus_limit=10):
    # search Elasticsearch for highlighted fragments mentioning the cell content,
    # both in the table's own paper and in the rest of the corpus
    evidence_query = Fragment.search().highlight(
        'text', pre_tags="<b>", post_tags="</b>", fragment_size=400)
    cell_content = cell_content.replace("\xa0", " ")
    query = {
        "query": cell_content,
        "slop": 2
    }
    paper_fragments = list(evidence_query
                           .filter('term', paper_id=paper_id)
                           .query('match_phrase', text=query)[:paper_limit])
    other_fragments = list(evidence_query
                           .exclude('term', paper_id=paper_id)
                           .query('match_phrase', text=query)[:corpus_limit])
    return paper_fragments + other_fragments


fix_refs_re = re.compile(r'\(\?\)|\s[?]+(\s|$)')


def fix_refs(text):
    # replace unresolved reference markers such as '(?)' or '??' with a placeholder token;
    # the substitution is applied twice to catch markers adjacent to an already replaced one
    return fix_refs_re.sub(' xref-unknown ', fix_refs_re.sub(' xref-unknown ', text))


highlight_re = re.compile("</?b>")


def create_evidence_records(textfrag, cell, table):
    for text_highlited in textfrag.meta['highlight']['text']:
        text_highlited = fix_refs(text_highlited)
        text = highlight_re.sub("", text_highlited)
        text_sha1 = hashlib.sha1(text.encode("utf-8")).hexdigest()

        cell_ext_id = f"{table.ext_id}/{cell.row}/{cell.col}"

        # keep only fragments longer than 50 words
        if len(text.split()) > 50:
            yield {"text_sha1": text_sha1,
                   "text_highlited": text_highlited,
                   "text": text,
                   "cell_type": cell.vals[1],
                   "cell_content": fix_refs(cell.vals[0]),
                   "this_paper": textfrag.paper_id == table.paper_id,
                   "row": cell.row,
                   "col": cell.col,
                   "ext_id": cell_ext_id
                   #"table_id": table_id
                   }


def filter_cells(cell):
    # keep only cells whose gold tag contains at least two letters
    return re.search("[a-zA-Z]{2,}", cell.vals[1]) is not None


def evidence_for_table(table, paper_limit=10, corpus_limit=1):
    records = [
        record
        for cell in consume_cells(table.matrix, table.matrix_gold_tags) if filter_cells(cell)
        for evidence in fetch_evidence(cell.vals[0], paper_id=table.paper_id,
                                       paper_limit=paper_limit, corpus_limit=corpus_limit)
        for record in create_evidence_records(evidence, cell, table=table)
    ]
    df = pd.DataFrame.from_records(records)
    return df


def evidence_for_tables(tables, paper_limit=100, corpus_limit=20):
    return pd.concat([evidence_for_table(table, paper_limit=paper_limit, corpus_limit=corpus_limit)
                      for table in progress_bar(tables)])


def prepare_data(tables, csv_path):
    df = evidence_for_tables(tables)
    df = df.drop_duplicates(
        ["cell_content", "text_highlited", "cell_type", "this_paper"])
    print("Number of text fragments ", len(df))
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(csv_path, index=None)
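
A minimal usage sketch for this module. It assumes the Elasticsearch index behind Fragment is populated and that annotated papers are available; the output path is only an example built from the new config entries:

from pathlib import Path
from sota_extractor2.data.structure import get_all_tables, prepare_data

# papers: iterable of annotated paper objects exposing arxiv_id and a Django-style
# table_set relation, as get_all_tables expects; not defined in this commit
tables = list(get_all_tables(papers))  # gold-tagged, non-trash tables with paper_id attached
prepare_data(tables, Path("data/datasets/structure/evidence.csv"))  # writes the evidence CSV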
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
import numpy as np
from ...helpers.training import set_seed


def split_by_cell_content(df, seed=42, split_column="cell_content"):
    # split on unique cell contents (not on rows) so that all evidence fragments
    # for a given cell content fall on the same side of the split;
    # roughly 10% of the contents go to the validation set
    set_seed(seed, "val_split")
    contents = np.random.permutation(df[split_column].unique())
    val_split = int(len(contents) * 0.1)
    val_keys = contents[:val_split]
    split = df[split_column].isin(val_keys)
    valid_df = df[split]
    train_df = df[~split]
    return train_df, valid_df
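
A short usage sketch, assuming df is the evidence DataFrame written by prepare_data above (the CSV path is the same hypothetical example; the module path of this helper is not shown in the diff):

import pandas as pd

df = pd.read_csv("data/datasets/structure/evidence.csv")
train_df, valid_df = split_by_cell_content(df, seed=42)  # ~90% train / ~10% validation by cell content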
Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
import re
import string
from fastai.text import *  # just for utility functions: pd, np, Path etc.

from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

from ...helpers.training import set_seed


def transform_df(df):
    # normalise reference anchors, numbers and "data set" before training
    df = df.replace(re.compile(r"(xxref|xxanchor)-[\w\d-]*"), "\\1 ")
    df = df.replace(re.compile(r"(^|[ ])\d+\.\d+\b"), " xxnum ")
    df = df.replace(re.compile(r"(^|[ ])\d\b"), " xxnum ")
    df = df.replace(re.compile(r"\bdata set\b"), " dataset ")
    df = df.drop_duplicates(["text", "cell_content", "cell_type"]).fillna("")
    return df


def train_valid_split(df, seed=42, by="cell_content"):
    # hold out ~10% of the unique cell contents for validation
    set_seed(seed, "val_split")
    contents = np.random.permutation(df[by].unique())
    val_split = int(len(contents) * 0.1)
    val_keys = contents[:val_split]
    split = df[by].isin(val_keys)
    valid_df = df[split]
    train_df = df[~split]
    return train_df, valid_df


def get_class_column(y, classIdx):
    if len(y.shape) == 1:
        return y == classIdx
    else:
        return y.iloc[:, classIdx]


def get_number_of_classes(y):
    if len(y.shape) == 1:
        return len(np.unique(y))
    else:
        return y.shape[1]


class NBSVM:
    # NB-SVM baseline: one logistic regression per class, trained on tf-idf
    # features scaled by naive Bayes log-count ratios
    def __init__(self, solver='liblinear', dual=True):
        self.solver = solver  # 'liblinear' for small datasets, 'lbfgs' for large ones
        self.dual = dual

    re_tok = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])')

    def tokenize(self, s):
        return self.re_tok.sub(r' \1 ', s).split()

    def pr(self, y_i, y):
        # smoothed sum of term counts over documents of class y_i
        p = self.trn_term_doc[y == y_i].sum(0)
        return (p + 1) / ((y == y_i).sum() + 1)

    def get_mdl(self, y):
        y = y.values
        # r is the naive Bayes log-count ratio used to scale the features
        r = np.log(self.pr(1, y) / self.pr(0, y))
        m = LogisticRegression(C=4, dual=self.dual, solver=self.solver, max_iter=1000)
        x_nb = self.trn_term_doc.multiply(r)
        return m.fit(x_nb, y), r

    def bow(self, X_train):
        self.n = X_train.shape[0]
        self.vec = TfidfVectorizer(ngram_range=(1, 2), tokenizer=self.tokenize,
                                   min_df=3, max_df=0.9, strip_accents='unicode', use_idf=1,
                                   smooth_idf=1, sublinear_tf=1)
        return self.vec.fit_transform(X_train)

    def train_models(self, y_train):
        self.models = []
        for i in range(0, self.c):
            print('fit', i)
            m, r = self.get_mdl(get_class_column(y_train, i))
            self.models.append((m, r))

    def fit(self, X_train, y_train):
        self.trn_term_doc = self.bow(X_train)
        self.c = get_number_of_classes(y_train)
        self.train_models(y_train)

    def predict_proba(self, X_test):
        preds = np.zeros((len(X_test), self.c))
        test_term_doc = self.vec.transform(X_test)
        for i in range(0, self.c):
            m, r = self.models[i]
            preds[:, i] = m.predict_proba(test_term_doc.multiply(r))[:, 1]
        return preds

    def validate(self, X_test, y_test):
        acc = (np.argmax(self.predict_proba(X_test), axis=1) == y_test).mean()
        return acc


def metrics(preds, true_y):
    y = true_y
    p = preds
    acc = (p == y).mean()
    tp = ((y != 0) & (p == y)).sum()
    fp = ((p != 0) & (p != y)).sum()
    prec = tp / (fp + tp)
    return {
        "precision": prec,
        "accuracy": acc,
        "TP": tp,
        "FP": fp,
    }


def preds_for_cell_content(test_df, probs, group_by=["cell_content"]):
    # majority vote of fragment-level predictions within each group
    test_df = test_df.copy()
    test_df["pred"] = np.argmax(probs, axis=1)
    grouped_preds = test_df.groupby(group_by)["pred"].agg(
        lambda x: x.value_counts().index[0])
    grouped_counts = test_df.groupby(group_by)["pred"].count()
    results = pd.DataFrame({'true': test_df.groupby(group_by)["label"].agg(lambda x: x.value_counts().index[0]),
                            'pred': grouped_preds,
                            'counts': grouped_counts})
    return results


def preds_for_cell_content_multi(test_df, probs, group_by=["cell_content"]):
    # sum class probabilities within each group, then take the argmax
    test_df = test_df.copy()
    probs_df = pd.DataFrame(probs, index=test_df.index)
    test_df = pd.concat([test_df, probs_df], axis=1)
    grouped_preds = np.argmax(test_df.groupby(
        group_by)[probs_df.columns].sum().values, axis=1)
    grouped_counts = test_df.groupby(group_by)["label"].count()
    results = pd.DataFrame({'true': test_df.groupby(group_by)["label"].agg(lambda x: x.value_counts().index[0]),
                            'pred': grouped_preds,
                            'counts': grouped_counts})
    return results


def test_model(model, tdf):
    probs = model(tdf["text"])
    preds = np.argmax(probs, axis=1)
    print("Results of categorisation on text fragment level")
    print(metrics(preds, tdf.label))

    print("Results per cell_content grouped using majority voting")
    results = preds_for_cell_content(tdf, probs)
    print(metrics(results["pred"], results["true"]))

    print("Results per cell_content grouped with multi category mean")
    results = preds_for_cell_content_multi(tdf, probs)
    print(metrics(results["pred"], results["true"]))

    print("Results per cell_content grouped with multi category mean - only on fragments from the same paper as the corresponding table")
    results = preds_for_cell_content_multi(
        tdf[tdf.this_paper], probs[tdf.this_paper])
    print(metrics(results["pred"], results["true"]))
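
A minimal end-to-end sketch using only functions defined in this file. It assumes df is the evidence DataFrame produced by prepare_data and that its label column holds integer class ids 0..C-1 with 0 as the negative class (which is what get_class_column and metrics expect):

df = transform_df(df)                                 # normalise references, numbers and "data set"
train_df, valid_df = train_valid_split(df, seed=42)

nbsvm = NBSVM()                                       # per-class logistic regression on NB-scaled tf-idf
nbsvm.fit(train_df["text"], train_df["label"])
print("validation accuracy:", nbsvm.validate(valid_df["text"], valid_df["label"]))

test_model(nbsvm.predict_proba, valid_df)             # fragment-level and per-cell-content metrics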
