3
3
4
4
from sota_extractor2 .models .linking .acronym_extractor import AcronymExtractor
5
5
from sota_extractor2 .models .linking .probs import get_probs , reverse_probs
6
- from sota_extractor2 .models .linking .utils import normalize_dataset , normalize_cell , normalize_cell_ws
6
+ from sota_extractor2 .models .linking .utils import normalize_dataset_ws , normalize_cell , normalize_cell_ws
7
7
from scipy .special import softmax
8
8
import re
9
9
import pandas as pd
@@ -201,9 +201,9 @@ def dummy_item(reason):
201
201
202
202
203
203
@njit
204
- def compute_logprobs (dataset_metric , reverse_merged_p , reverse_metrics_p , dss , mss , noise , logprobs ):
204
+ def compute_logprobs (taxonomy , reverse_merged_p , reverse_metrics_p , dss , mss , noise , logprobs ):
205
205
empty = typed .Dict .empty (types .unicode_type , types .float64 )
206
- for i , (dataset , metric ) in enumerate (dataset_metric ):
206
+ for i , (task , dataset , metric ) in enumerate (taxonomy ):
207
207
logprob = 0.0
208
208
short_probs = reverse_merged_p .get (dataset , empty )
209
209
met_probs = reverse_metrics_p .get (metric , empty )
@@ -223,16 +223,16 @@ def compute_logprobs(dataset_metric, reverse_merged_p, reverse_metrics_p, dss, m
223
223
class ContextSearch :
224
224
def __init__ (self , taxonomy , context_noise = (0.5 , 0.2 , 0.1 ), debug_gold_df = None ):
225
225
merged_p = \
226
- get_probs ({k : Counter ([normalize_cell (normalize_dataset (x )) for x in v ]) for k , v in datasets .items ()})[1 ]
226
+ get_probs ({k : Counter ([normalize_cell (normalize_dataset_ws (x )) for x in v ]) for k , v in datasets .items ()})[1 ]
227
227
metrics_p = \
228
- get_probs ({k : Counter ([normalize_cell (normalize_dataset (x )) for x in v ]) for k , v in metrics .items ()})[1 ]
228
+ get_probs ({k : Counter ([normalize_cell (normalize_dataset_ws (x )) for x in v ]) for k , v in metrics .items ()})[1 ]
229
229
230
230
231
231
self .queries = {}
232
232
self .taxonomy = taxonomy
233
- self ._dataset_metric = typed .List ()
233
+ self ._taxonomy = typed .List ()
234
234
for t in self .taxonomy .taxonomy :
235
- self ._dataset_metric .append (t )
235
+ self ._taxonomy .append (t )
236
236
self .extract_acronyms = AcronymExtractor ()
237
237
self .context_noise = context_noise
238
238
self .reverse_merged_p = self ._numba_update_nested_dict (reverse_probs (merged_p ))
@@ -254,8 +254,9 @@ def _numba_extend_list(self, lst):
254
254
return l
255
255
256
256
def compute_context_logprobs (self , context , noise , logprobs ):
257
+ context = context or ""
257
258
abbrvs = self .extract_acronyms (context )
258
- context = normalize_cell_ws (normalize_dataset (context ))
259
+ context = normalize_cell_ws (normalize_dataset_ws (context ))
259
260
dss = set (find_datasets (context )) | set (abbrvs .keys ())
260
261
mss = set (find_metrics (context ))
261
262
dss -= mss
@@ -265,16 +266,16 @@ def compute_context_logprobs(self, context, noise, logprobs):
265
266
###print("mss", mss)
266
267
dss = self ._numba_extend_list (dss )
267
268
mss = self ._numba_extend_list (mss )
268
- compute_logprobs (self ._dataset_metric , self .reverse_merged_p , self .reverse_metrics_p , dss , mss , noise , logprobs )
269
+ compute_logprobs (self ._taxonomy , self .reverse_merged_p , self .reverse_metrics_p , dss , mss , noise , logprobs )
269
270
270
271
def match (self , contexts ):
271
272
assert len (contexts ) == len (self .context_noise )
272
- n = len (self ._dataset_metric )
273
- context_logprobs = np .ones (n )
273
+ n = len (self ._taxonomy )
274
+ context_logprobs = np .zeros (n )
274
275
275
276
for context , noise in zip (contexts , self .context_noise ):
276
277
self .compute_context_logprobs (context , noise , context_logprobs )
277
- keys = self .taxonomy .taxonomy . keys ()
278
+ keys = self .taxonomy .taxonomy
278
279
logprobs = context_logprobs
279
280
#keys, logprobs = zip(*context_logprobs.items())
280
281
probs = softmax (np .array (logprobs ))
@@ -290,12 +291,12 @@ def __call__(self, query, datasets, caption, debug_info=None):
290
291
###print("query:", query, caption)
291
292
if key in self .queries :
292
293
# print(self.queries[key])
293
- for context in key :
294
- abbrvs = self .extract_acronyms (context )
295
- context = normalize_cell_ws (normalize_dataset (context ))
296
- dss = set (find_datasets (context )) | set (abbrvs .keys ())
297
- mss = set (find_metrics (context ))
298
- dss -= mss
294
+ # for context in key:
295
+ # abbrvs = self.extract_acronyms(context)
296
+ # context = normalize_cell_ws(normalize_dataset_ws (context))
297
+ # dss = set(find_datasets(context)) | set(abbrvs.keys())
298
+ # mss = set(find_metrics(context))
299
+ # dss -= mss
299
300
###print("dss", dss)
300
301
###print("mss", mss)
301
302
@@ -307,7 +308,8 @@ def __call__(self, query, datasets, caption, debug_info=None):
307
308
308
309
entries = []
309
310
for it , prob in topk :
310
- entry = dict (self .taxonomy .taxonomy [it ])
311
+ task , dataset , metric = it
312
+ entry = dict (task = task , dataset = dataset , metric = metric )
311
313
entry .update ({"evidence" : "" , "confidence" : prob })
312
314
entries .append (entry )
313
315
@@ -351,4 +353,4 @@ def from_paper(self, paper):
351
353
return self (text )
352
354
353
355
def __call__ (self , text ):
354
- return find_datasets (normalize_cell_ws (normalize_dataset (text )))
356
+ return find_datasets (normalize_cell_ws (normalize_dataset_ws (text )))
0 commit comments