import re
import pandas as pd
import numpy as np
+ import ahocorasick
+ from numba import njit, typed, types

from sota_extractor2.pipeline_logger import pipeline_logger

    'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
})

- escaped_ws_re = re.compile(r'\\\s+')
- def name_to_re(name):
-     return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
+ # escaped_ws_re = re.compile(r'\\\s+')
+ # def name_to_re(name):
+ #     return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)

#all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
all_datasets = set(y for x in datasets.values() for y in x)
all_metrics = set(y for x in metrics.values() for y in x)
#all_metrics = set(metrics_p.keys())

- all_datasets_re = {x: name_to_re(x) for x in all_datasets}
- all_metrics_re = {x: name_to_re(x) for x in all_metrics}
+ # all_datasets_re = {x: name_to_re(x) for x in all_datasets}
+ # all_metrics_re = {x: name_to_re(x) for x in all_metrics}

#all_datasets = set(x for v in merged_p.values() for x in v)

- def find_names(text, names_re):
-     return set(name for name, name_re in names_re.items() if name_re.search(text))
+ # def find_names(text, names_re):
+ #     return set(name for name, name_re in names_re.items() if name_re.search(text))
+
+
+ def make_trie(names):
+     trie = ahocorasick.Automaton()
+     for name in names:
+         norm = name.replace(" ", "")
+         trie.add_word(norm, (len(norm), name))
+     trie.make_automaton()
+     return trie
+
+
+ single_letter_re = re.compile(r"\b\w\b")
+ init_letter_re = re.compile(r"\b\w")
+ end_letter_re = re.compile(r"\w\b")
+ letter_re = re.compile(r"\w")
+
+
+ def find_names(text, names_trie):
+     text = text.lower()
+     profile = letter_re.sub("i", text)
+     profile = init_letter_re.sub("b", profile)
+     profile = end_letter_re.sub("e", profile)
+     profile = single_letter_re.sub("x", profile)
+     text = text.replace(" ", "")
+     profile = profile.replace(" ", "")
+     s = set()
+     for end, (l, word) in names_trie.iter(text):
+         if profile[end] in ['e', 'x'] and profile[end - l + 1] in ['b', 'x']:
+             s.add(word)
+     return s
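# Illustrative sketch (not part of the original change): a single Aho-Corasick
# automaton replaces the per-name regular expressions above. make_trie stores
# each name with spaces removed, and find_names checks the "profile" string
# (b/e/x/i = begin/end/single/inner letter) so a hit must start and end on
# word boundaries of the original text. Assuming pyahocorasick is installed,
# the expected behaviour is roughly:
#
#     trie = make_trie({"libri speech dev other"})
#     find_names("results on libri speech dev other", trie)   # -> {'libri speech dev other'}
#     find_names("xlibrispeechdevothery", trie)                # -> set(), no word boundary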
+
+
+ all_datasets_trie = make_trie(all_datasets)
+ all_metrics_trie = make_trie(all_metrics)
+

def find_datasets(text):
-     return find_names(text, all_datasets_re)
+     return find_names(text, all_datasets_trie)

def find_metrics(text):
-     return find_names(text, all_metrics_re)
+     return find_names(text, all_metrics_trie)

def dummy_item(reason):
    return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))

+ @njit
+ def compute_logprobs(dataset_metric, reverse_merged_p, reverse_metrics_p, dss, mss, noise, logprobs):
+     empty = typed.Dict.empty(types.unicode_type, types.float64)
+     for i, (dataset, metric) in enumerate(dataset_metric):
+         logprob = 0.0
+         short_probs = reverse_merged_p.get(dataset, empty)
+         met_probs = reverse_metrics_p.get(metric, empty)
+         for ds in dss:
+             # for abbrv, long_form in abbrvs.items():
+             #     if ds == abbrv:
+             #         ds = long_form
+             #         break
+             # if merged_p[ds].get('NOMATCH', 0.0) < 0.5:
+             logprob += np.log(noise * 0.001 + (1 - noise) * short_probs.get(ds, 0.0))
+         for ms in mss:
+             logprob += np.log(noise * 0.01 + (1 - noise) * met_probs.get(ms, 0.0))
+         logprobs[i] += logprob
+         #logprobs[(dataset, metric)] = logprob
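# Illustrative note (not part of the original change): each mention found in a
# context contributes
#     log(noise * floor + (1 - noise) * P(mention | taxonomy entry))
# with floor = 0.001 for datasets and 0.01 for metrics, so mentions absent from
# the reverse-probability tables add a bounded penalty instead of driving the
# log-probability to -inf.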
+

class ContextSearch:
    def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
@@ -174,47 +230,53 @@ def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
        self.queries = {}
        self.taxonomy = taxonomy
+         self._dataset_metric = typed.List()
+         for t in self.taxonomy.taxonomy:
+             self._dataset_metric.append(t)
        self.extract_acronyms = AcronymExtractor()
        self.context_noise = context_noise
-         self.reverse_merged_p = reverse_probs(merged_p)
-         self.reverse_metrics_p = reverse_probs(metrics_p)
+         self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
+         self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
        self.debug_gold_df = debug_gold_df

-     def compute_logprobs(self, dss, mss, abbrvs, noise, logprobs):
-         for dataset, metric in self.taxonomy.taxonomy:
-             logprob = logprobs.get((dataset, metric), 1.0)
-             short_probs = self.reverse_merged_p.get(dataset, {})
-             met_probs = self.reverse_metrics_p.get(metric, {})
-             for ds in dss:
-                 ds = normalize_cell(ds)
-                 # for abbrv, long_form in abbrvs.items():
-                 #     if ds == abbrv:
-                 #         ds = long_form
-                 #         break
-                 # if merged_p[ds].get('NOMATCH', 0.0) < 0.5:
-                 logprob += np.log(noise * 0.001 + (1 - noise) * short_probs.get(ds, 0.0))
-             for ms in mss:
-                 ms = normalize_cell(ms)
-                 logprob += np.log(noise * 0.01 + (1 - noise) * met_probs.get(ms, 0.0))
-             logprobs[(dataset, metric)] = logprob
+     def _numba_update_nested_dict(self, nested):
+         d = typed.Dict()
+         for key, dct in nested.items():
+             d2 = typed.Dict()
+             d2.update(dct)
+             d[key] = d2
+         return d
+
+     def _numba_extend_list(self, lst):
+         l = typed.List.empty_list(types.unicode_type)
+         for x in lst:
+             l.append(x)
+         return l
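# Illustrative sketch (not part of the original change): the typed copies exist
# because the @njit-compiled compute_logprobs cannot accept plain Python dicts
# or lists. The helpers above mirror numba's typed containers, roughly:
#
#     d = typed.Dict()                              # key/value types inferred on first insert
#     d.update({"wer": 0.9})                        # unicode -> float64
#     l = typed.List.empty_list(types.unicode_type)
#     l.append("word error rate")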

    def compute_context_logprobs(self, context, noise, logprobs):
        abbrvs = self.extract_acronyms(context)
        context = normalize_cell_ws(normalize_dataset(context))
        dss = set(find_datasets(context)) | set(abbrvs.keys())
        mss = set(find_metrics(context))
        dss -= mss
+         dss = [normalize_cell(ds) for ds in dss]
+         mss = [normalize_cell(ms) for ms in mss]
        ###print("dss", dss)
        ###print("mss", mss)
-         self.compute_logprobs(dss, mss, abbrvs, noise, logprobs)
+         dss = self._numba_extend_list(dss)
+         mss = self._numba_extend_list(mss)
+         compute_logprobs(self._dataset_metric, self.reverse_merged_p, self.reverse_metrics_p, dss, mss, noise, logprobs)

    def match(self, contexts):
        assert len(contexts) == len(self.context_noise)
-         context_logprobs = {}
+         n = len(self._dataset_metric)
+         context_logprobs = np.ones(n)

        for context, noise in zip(contexts, self.context_noise):
            self.compute_context_logprobs(context, noise, context_logprobs)
-         keys, logprobs = zip(*context_logprobs.items())
+         keys = self.taxonomy.taxonomy.keys()
+         logprobs = context_logprobs
+         #keys, logprobs = zip(*context_logprobs.items())
        probs = softmax(np.array(logprobs))
        return zip(keys, probs)
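# Usage sketch (hypothetical call, not part of the original change): with the
# default context_noise=(0.5, 0.2, 0.1), match() expects exactly three context
# strings and yields (dataset, metric) taxonomy keys with softmax-normalised
# probabilities:
#
#     cs = ContextSearch(taxonomy)
#     ranked = sorted(cs.match([ctx_a, ctx_b, ctx_c]), key=lambda kp: kp[1], reverse=True)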