3
3
4
4
from sota_extractor2 .models .linking .acronym_extractor import AcronymExtractor
5
5
from sota_extractor2 .models .linking .probs import get_probs , reverse_probs
6
- from sota_extractor2 .models .linking .utils import normalize_dataset_ws , normalize_cell , normalize_cell_ws
6
+ from sota_extractor2 .models .linking .utils import normalize_dataset , normalize_cell , normalize_cell_ws
7
7
from scipy .special import softmax
8
8
import re
9
9
import pandas as pd
108
108
'Rain100L' : ['rain100l' ],
109
109
'Rain12' : ['rain12' ],
110
110
'Rain800' : ['rain800' ],
111
- 'Rain1400' : ['rain1400' ],
112
- 'Real Rain' : ['real rain' ],
113
- 'Rain in Surveillance' : ['ris' ],
114
- 'Rain in Driving' : ['rid' ],
111
+ 'Rain1400' : ['rain1400' ],
112
+ 'Real Rain' : ['real rain' ],
113
+ 'Rain in Surveillance' : ['ris' ],
114
+ 'Rain in Driving' : ['rid' ],
115
115
'DID-MDN' : ['did-mdn' ],
116
116
'SOTS' : ['sots' ],
117
117
'Test 1' : ['test 1' ],
118
118
'RainSynLight25' : ['rainsynlight25' ],
119
- 'RainSynComplex25' : ['rainsyncomplex25' ],
120
- 'NTURain' : ['nturain' ],
119
+ 'RainSynComplex25' : ['rainsyncomplex25' ],
120
+ 'NTURain' : ['nturain' ],
121
121
'RainSynAll100' : ['rainsynall100' ],
122
122
'SPA-DATA' : ['spa-data' ],
123
123
'LasVR' : ['lasvar' ],
143
143
# return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
144
144
145
145
#all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
146
- all_datasets = set (y for x in datasets .values () for y in x )
147
- all_metrics = set (y for x in metrics .values () for y in x )
146
+ all_datasets = set (normalize_cell_ws ( normalize_dataset ( y )) for x in datasets .values () for y in x )
147
+ all_metrics = set (normalize_cell_ws ( y ) for x in metrics .values () for y in x )
148
148
#all_metrics = set(metrics_p.keys())
149
149
150
150
# all_datasets_re = {x:name_to_re(x) for x in all_datasets}
@@ -201,7 +201,7 @@ def dummy_item(reason):
201
201
202
202
203
203
@njit
204
- def compute_logprobs (taxonomy , reverse_merged_p , reverse_metrics_p , dss , mss , noise , logprobs ):
204
+ def compute_logprobs (taxonomy , reverse_merged_p , reverse_metrics_p , dss , mss , noise , ms_noise , ds_pb , ms_pb , logprobs ):
205
205
empty = typed .Dict .empty (types .unicode_type , types .float64 )
206
206
for i , (task , dataset , metric ) in enumerate (taxonomy ):
207
207
logprob = 0.0
@@ -213,19 +213,19 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, no
213
213
# ds = long_form
214
214
# break
215
215
# if merged_p[ds].get('NOMATCH', 0.0) < 0.5:
216
- logprob += np .log (noise * 0.001 + (1 - noise ) * short_probs .get (ds , 0.0 ))
216
+ logprob += np .log (noise * ds_pb + (1 - noise ) * short_probs .get (ds , 0.0 ))
217
217
for ms in mss :
218
- logprob += np .log (noise * 0.01 + (1 - noise ) * met_probs .get (ms , 0.0 ))
218
+ logprob += np .log (ms_noise * ms_pb + (1 - ms_noise ) * met_probs .get (ms , 0.0 ))
219
219
logprobs [i ] += logprob
220
220
#logprobs[(dataset, metric)] = logprob
221
221
222
222
223
223
class ContextSearch :
224
- def __init__ (self , taxonomy , context_noise = (0.5 , 0.2 , 0.1 ), debug_gold_df = None ):
224
+ def __init__ (self , taxonomy , context_noise = (0.5 , 0.2 , 0.1 ), metrics_noise = None , ds_pb = 0.001 , ms_pb = 0.01 , debug_gold_df = None ):
225
225
merged_p = \
226
- get_probs ({k : Counter ([normalize_cell (normalize_dataset_ws (x )) for x in v ]) for k , v in datasets .items ()})[1 ]
226
+ get_probs ({k : Counter ([normalize_cell (normalize_dataset (x )) for x in v ]) for k , v in datasets .items ()})[1 ]
227
227
metrics_p = \
228
- get_probs ({k : Counter ([normalize_cell (normalize_dataset_ws (x )) for x in v ]) for k , v in metrics .items ()})[1 ]
228
+ get_probs ({k : Counter ([normalize_cell (normalize_dataset (x )) for x in v ]) for k , v in metrics .items ()})[1 ]
229
229
230
230
231
231
self .queries = {}
@@ -235,6 +235,9 @@ def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
235
235
self ._taxonomy .append (t )
236
236
self .extract_acronyms = AcronymExtractor ()
237
237
self .context_noise = context_noise
238
+ self .metrics_noise = metrics_noise if metrics_noise else context_noise
239
+ self .ds_pb = ds_pb
240
+ self .ms_pb = ms_pb
238
241
self .reverse_merged_p = self ._numba_update_nested_dict (reverse_probs (merged_p ))
239
242
self .reverse_metrics_p = self ._numba_update_nested_dict (reverse_probs (metrics_p ))
240
243
self .debug_gold_df = debug_gold_df
@@ -253,10 +256,10 @@ def _numba_extend_list(self, lst):
253
256
l .append (x )
254
257
return l
255
258
256
- def compute_context_logprobs (self , context , noise , logprobs ):
259
+ def compute_context_logprobs (self , context , noise , ms_noise , logprobs ):
257
260
context = context or ""
258
261
abbrvs = self .extract_acronyms (context )
259
- context = normalize_cell_ws (normalize_dataset_ws (context ))
262
+ context = normalize_cell_ws (normalize_dataset (context ))
260
263
dss = set (find_datasets (context )) | set (abbrvs .keys ())
261
264
mss = set (find_metrics (context ))
262
265
dss -= mss
@@ -266,15 +269,16 @@ def compute_context_logprobs(self, context, noise, logprobs):
266
269
###print("mss", mss)
267
270
dss = self ._numba_extend_list (dss )
268
271
mss = self ._numba_extend_list (mss )
269
- compute_logprobs (self ._taxonomy , self .reverse_merged_p , self .reverse_metrics_p , dss , mss , noise , logprobs )
272
+ compute_logprobs (self ._taxonomy , self .reverse_merged_p , self .reverse_metrics_p ,
273
+ dss , mss , noise , ms_noise , self .ds_pb , self .ms_pb , logprobs )
270
274
271
275
def match (self , contexts ):
272
276
assert len (contexts ) == len (self .context_noise )
273
277
n = len (self ._taxonomy )
274
278
context_logprobs = np .zeros (n )
275
279
276
- for context , noise in zip (contexts , self .context_noise ):
277
- self .compute_context_logprobs (context , noise , context_logprobs )
280
+ for context , noise , ms_noise in zip (contexts , self .context_noise , self . metrics_noise ):
281
+ self .compute_context_logprobs (context , noise , ms_noise , context_logprobs )
278
282
keys = self .taxonomy .taxonomy
279
283
logprobs = context_logprobs
280
284
#keys, logprobs = zip(*context_logprobs.items())
@@ -293,7 +297,7 @@ def __call__(self, query, datasets, caption, debug_info=None):
293
297
# print(self.queries[key])
294
298
# for context in key:
295
299
# abbrvs = self.extract_acronyms(context)
296
- # context = normalize_cell_ws(normalize_dataset_ws (context))
300
+ # context = normalize_cell_ws(normalize_dataset (context))
297
301
# dss = set(find_datasets(context)) | set(abbrvs.keys())
298
302
# mss = set(find_metrics(context))
299
303
# dss -= mss
@@ -353,4 +357,4 @@ def from_paper(self, paper):
353
357
return self (text )
354
358
355
359
def __call__ (self , text ):
356
- return find_datasets (normalize_cell_ws (normalize_dataset_ws (text )))
360
+ return find_datasets (normalize_cell_ws (normalize_dataset (text )))
0 commit comments