 from sota_extractor2.pipeline_logger import pipeline_logger

 from sota_extractor2.models.linking import manual_dicts
+from collections import Counter

 def dummy_item(reason):
     return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))
@@ -64,10 +65,10 @@ def find_names(text, names_trie):
         profile = EvidenceFinder.single_letter_re.sub("x", profile)
         text = text.replace(" ", "")
         profile = profile.replace(" ", "")
-        s = set()
+        s = Counter()
         for (end, (l, word)) in names_trie.iter(text):
             if profile[end] in ['e', 'x'] and profile[end - l + 1] in ['b', 'x']:
-                s.add(word)
+                s[word] += 1
         return s

     def find_datasets(self, text):
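Note on the hunk above: `find_names` now returns a `Counter` instead of a `set`, so repeated mentions of the same name are preserved rather than collapsed. A minimal sketch of that behaviour, assuming a pyahocorasick-style automaton whose `iter()` yields `(end_index, payload)` pairs; the automaton setup below is illustrative and not the repository's actual EvidenceFinder initialization:

```python
from collections import Counter
import ahocorasick

# Illustrative automaton; EvidenceFinder builds its tries elsewhere.
names_trie = ahocorasick.Automaton()
for name in ["ImageNet", "CIFAR-10"]:
    names_trie.add_word(name.lower(), (len(name), name))
names_trie.make_automaton()

text = "imagenet results, imagenet top-1, cifar-10"
counts = Counter(word for _end, (_l, word) in names_trie.iter(text))
print(counts)  # Counter({'ImageNet': 2, 'CIFAR-10': 1}); a set would collapse the duplicates
```

These counts are what later lets `axis_logprobs` weight repeated evidence while capping it with `max_repetitions`.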
@@ -105,30 +106,31 @@ def _init_structs(self, taxonomy):


 @njit
-def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
+def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb, max_repetitions):
     logprob = 0.0
     empty = typed.Dict.empty(types.unicode_type, types.float64)
     short_probs = reverse_probs.get(evidences_for, empty)
-    for evidence in found_evidences:
-        logprob += np.log(noise * pb + (1 - noise) * short_probs.get(evidence, 0.0))
+    for evidence, count in found_evidences.items():
+        logprob += min(count, max_repetitions) * np.log(noise * pb + (1 - noise) * short_probs.get(evidence, 0.0))
     return logprob


 # compute log-probabilities in a given context and add them to logprobs
 @njit
 def compute_logprobs(taxonomy, tasks, datasets, metrics,
                      reverse_merged_p, reverse_metrics_p, reverse_task_p,
-                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs):
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs,
+                     max_repetitions):
     task_cache = typed.Dict.empty(types.unicode_type, types.float64)
     dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
     metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
     for i, (task, dataset, metric) in enumerate(taxonomy):
         if dataset not in dataset_cache:
-            dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb)
+            dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb, 1)
         if metric not in metric_cache:
-            metric_cache[metric] = axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb)
+            metric_cache[metric] = axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb, 1)
         if task not in task_cache:
-            task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
+            task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb, max_repetitions)

         logprobs[i] += dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
     for i, task in enumerate(tasks):
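The `max_repetitions` argument added above caps how often a repeated evidence string can contribute: datasets and metrics are capped at 1 (matching the old set semantics), while tasks use the configurable `self.max_repetitions`. A plain-Python sketch of that scoring, with the numba typed dicts replaced by ordinary dicts and made-up probabilities purely for illustration:

```python
import numpy as np
from collections import Counter

def axis_logprobs_py(evidences_for, reverse_probs, found_evidences, noise, pb, max_repetitions):
    # Same logic as the @njit axis_logprobs above, without numba types.
    short_probs = reverse_probs.get(evidences_for, {})
    logprob = 0.0
    for evidence, count in found_evidences.items():
        # Each distinct evidence contributes at most max_repetitions times.
        logprob += min(count, max_repetitions) * np.log(noise * pb + (1 - noise) * short_probs.get(evidence, 0.0))
    return logprob

reverse_probs = {"ImageNet": {"imagenet": 0.9, "image net": 0.1}}   # made-up numbers
evidence = Counter({"imagenet": 7, "image net": 1})
# With max_repetitions=1 the seven "imagenet" mentions count only once.
print(axis_logprobs_py("ImageNet", reverse_probs, evidence, noise=0.5, pb=0.001, max_repetitions=1))
```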
@@ -149,7 +151,7 @@ def _to_typed_list(iterable):


 class ContextSearch:
-    def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, task_noise=None,
+    def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.1, 0.2, 0.2, 0.1), metric_noise=None, task_noise=None,
                  ds_pb=0.001, ms_pb=0.01, ts_pb=0.01, debug_gold_df=None):
         merged_p = \
             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.datasets.items()})[1]
@@ -169,7 +171,7 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), met

         self.extract_acronyms = AcronymExtractor()
         self.context_noise = context_noise
-        self.metrics_noise = metrics_noise if metrics_noise else context_noise
+        self.metrics_noise = metric_noise if metric_noise else context_noise
         self.task_noise = task_noise if task_noise else context_noise
         self.ds_pb = ds_pb
         self.ms_pb = ms_pb
@@ -178,6 +180,7 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), met
         self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
         self.reverse_tasks_p = self._numba_update_nested_dict(reverse_probs(tasks_p))
         self.debug_gold_df = debug_gold_df
+        self.max_repetitions = 1

     def _numba_update_nested_dict(self, nested):
         d = typed.Dict()
@@ -188,32 +191,43 @@ def _numba_update_nested_dict(self, nested):
         return d

     def _numba_extend_list(self, lst):
-        l = typed.List.empty_list(types.unicode_type)
+        l = typed.List.empty_list((types.unicode_type, types.int32))
         for x in lst:
             l.append(x)
         return l

+    def _numba_extend_dict(self, dct):
+        d = typed.Dict.empty(types.unicode_type, types.int64)
+        d.update(dct)
+        return d
+
     def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
-        context = context or ""
-        abbrvs = self.extract_acronyms(context)
-        context = normalize_cell_ws(normalize_dataset_ws(context))
-        dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
-        mss = set(self.evidence_finder.find_metrics(context))
-        tss = set(self.evidence_finder.find_tasks(context))
-        dss -= mss
-        dss -= tss
-        dss = [normalize_cell(ds) for ds in dss]
-        mss = [normalize_cell(ms) for ms in mss]
-        tss = [normalize_cell(ts) for ts in tss]
+        if isinstance(context, str) or context is None:
+            context = context or ""
+            #abbrvs = self.extract_acronyms(context)
+            context = normalize_cell_ws(normalize_dataset_ws(context))
+            #dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
+            dss = self.evidence_finder.find_datasets(context)
+            mss = self.evidence_finder.find_metrics(context)
+            tss = self.evidence_finder.find_tasks(context)
+
+            dss -= mss
+            dss -= tss
+        else:
+            tss, dss, mss = context
+
+        dss = {normalize_cell(ds): count for ds, count in dss.items()}
+        mss = {normalize_cell(ms): count for ms, count in mss.items()}
+        tss = {normalize_cell(ts): count for ts, count in tss.items()}
         ###print("dss", dss)
         ###print("mss", mss)
-        dss = self._numba_extend_list(dss)
-        mss = self._numba_extend_list(mss)
-        tss = self._numba_extend_list(tss)
+        dss = self._numba_extend_dict(dss)
+        mss = self._numba_extend_dict(mss)
+        tss = self._numba_extend_dict(tss)
         compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
                          self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
                          dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs,
-                         axes_logprobs)
+                         axes_logprobs, self.max_repetitions)

     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
@@ -239,11 +253,16 @@ def match(self, contexts):
             zip(self.taxonomy.metrics, axes_probs[2])
         )

-    def __call__(self, query, datasets, caption, topk=1, debug_info=None):
+    def __call__(self, query, paper_context, abstract_context, table_context, caption, topk=1, debug_info=None):
         cellstr = debug_info.cell.cell_ext_id
-        pipeline_logger("linking::taxonomy_linking::call", ext_id=cellstr, query=query, datasets=datasets, caption=caption)
-        datasets = " ".join(datasets)
-        key = (datasets, caption, query, topk)
+        pipeline_logger("linking::taxonomy_linking::call", ext_id=cellstr, query=query,
+                        paper_context=paper_context, abstract_context=abstract_context, table_context=table_context,
+                        caption=caption)
+
+        paper_hash = ";".join(",".join(s.elements()) for s in paper_context)
+        abstract_hash = ";".join(",".join(s.elements()) for s in abstract_context)
+        mentions_hash = ";".join(",".join(s.elements()) for s in table_context)
+        key = (paper_hash, abstract_hash, mentions_hash, caption, query, topk)
         ###print(f"[DEBUG] {cellstr}")
         ###print("[DEBUG]", debug_info)
         ###print("query:", query, caption)
@@ -261,7 +280,7 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
             ###print("Taking result from cache")
             p = self.queries[key]
         else:
-            dists = self.match((datasets, caption, query))
+            dists = self.match((paper_context, abstract_context, table_context, caption, query))

             all_top_results = [sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
             top_results, top_results_t, top_results_d, top_results_m = all_top_results
@@ -279,7 +298,10 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
             # task=top_results_t[i][0],
             # dataset=top_results_d[i][0],
             # metric=top_results_m[i][0])
-            # best_independent.update({"evidence": "", "confidence": top_results_t[i][1]})
+            # best_independent.update({
+            #     "evidence": "",
+            #     "confidence": np.power(top_results_t[i][1] * top_results_d[i][1] * top_results_m[i][1], 1.0/3.0)
+            # })
             # entries.append(best_independent)
             #entries = [best_independent] + entries

@@ -314,18 +336,51 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):


 # todo: compare regex approach (old) with find_datasets(.) (current)
+# todo: rename it
 class DatasetExtractor:
     def __init__(self, evidence_finder):
         self.evidence_finder = evidence_finder
         self.dataset_prefix_re = re.compile(r"[A-Z]|[a-z]+[A-Z]+|[0-9]")
         self.dataset_name_re = re.compile(r"\b(the)\b\s*(?P<name>((?!(the)\b)\w+\W+){1,10}?)(test|val(\.|idation)?|dev(\.|elopment)?|train(\.|ing)?\s+)?\bdata\s*set\b", re.IGNORECASE)

+    def find_references(self, text, references):
+        refs = r"\bxxref-(" + "|".join([re.escape(ref) for ref in references]) + r")\b"
+        return set(re.findall(refs, text))
+
+    def get_table_contexts(self, paper, tables):
+        ref_tables = [table for table in tables if table.figure_id]
+        refs = [table.figure_id.replace(".", "") for table in ref_tables]
+        ref_contexts = {ref: [Counter(), Counter(), Counter()] for ref in refs}
+        if hasattr(paper.text, "fragments"):
+            for fragment in paper.text.fragments:
+                found_refs = self.find_references(fragment.text, refs)
+                if found_refs:
+                    ts, ds, ms = self(fragment.header + "\n " + fragment.text)
+                    for ref in found_refs:
+                        ref_contexts[ref][0] += ts
+                        ref_contexts[ref][1] += ds
+                        ref_contexts[ref][2] += ms
+        table_contexts = [
+            ref_contexts.get(
+                table.figure_id.replace(".", ""),
+                [Counter(), Counter(), Counter()]
+            ) if table.figure_id else [Counter(), Counter(), Counter()]
+            for table in tables
+        ]
+        return table_contexts
+
     def from_paper(self, paper):
-        text = paper.text.abstract
+        abstract = paper.text.abstract
+        text = ""
         if hasattr(paper.text, "fragments"):
             text += " ".join(f.text for f in paper.text.fragments)
-        return self(text)
+        return self(text), self(abstract)

     def __call__(self, text):
         text = normalize_cell_ws(normalize_dataset_ws(text))
-        return self.evidence_finder.find_datasets(text) | self.evidence_finder.find_tasks(text)
+        ds = self.evidence_finder.find_datasets(text)
+        ts = self.evidence_finder.find_tasks(text)
+        ms = self.evidence_finder.find_metrics(text)
+        ds -= ts
+        ds -= ms
+        return ts, ds, ms
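For orientation, a hedged sketch of how the reworked pieces could fit together: `DatasetExtractor` now yields `(tasks, datasets, metrics)` Counters, `from_paper` returns separate paper and abstract contexts, `get_table_contexts` aggregates fragment evidence per referenced table, and `ContextSearch.__call__` consumes the three contexts plus the caption and query. The `paper`, `paper.tables`, `table.caption` and `debug_info` objects are assumed to come from the surrounding pipeline and are not defined in this patch; the query string is a placeholder.

```python
# Assumed to exist: evidence_finder, taxonomy, paper (with .text and .tables), per-cell debug_info.
extractor = DatasetExtractor(evidence_finder)
context_search = ContextSearch(taxonomy, evidence_finder)

paper_context, abstract_context = extractor.from_paper(paper)       # each a (ts, ds, ms) triple of Counters
table_contexts = extractor.get_table_contexts(paper, paper.tables)  # one [ts, ds, ms] triple per table

for table, table_context in zip(paper.tables, table_contexts):
    proposals = context_search("top-1 accuracy",                    # placeholder cell query
                               paper_context, abstract_context, table_context,
                               caption=table.caption, topk=1, debug_info=debug_info)
```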