    'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
})

+tasks = {}
+
# escaped_ws_re = re.compile(r'\\\s+')
# def name_to_re(name):
#     return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)

#all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
all_datasets = set(normalize_cell_ws(normalize_dataset(y)) for x in datasets.values() for y in x)
all_metrics = set(normalize_cell_ws(y) for x in metrics.values() for y in x)
+all_tasks = set(normalize_cell_ws(normalize_dataset(y)) for x in tasks.values() for y in x)
+
#all_metrics = set(metrics_p.keys())

# all_datasets_re = {x:name_to_re(x) for x in all_datasets}
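
The new `tasks` dict is introduced empty in this commit, so `all_tasks` (and the task trie built below) start out empty. Judging by the `datasets` dict it sits next to, it is presumably meant to map a canonical task name to its surface forms; a purely hypothetical entry, not part of this commit, might look like:

```python
# Hypothetical example of the expected shape; the commit adds the dict empty.
tasks = {
    'Speech Recognition': ['speech recognition', 'asr', 'automatic speech recognition'],
}
```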
@@ -187,6 +191,7 @@ def find_names(text, names_trie):

all_datasets_trie = make_trie(all_datasets)
all_metrics_trie = make_trie(all_metrics)
+all_tasks_trie = make_trie(all_tasks)


def find_datasets(text):
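
`make_trie` and `find_names` are defined earlier in this file and their bodies are outside this diff. For orientation only, a minimal sketch of the kind of word-level trie matching their call sites suggest, with both helpers re-implemented here as assumptions rather than the repository's actual code:

```python
# Illustrative sketch only -- not the repository's implementation.
END = '$'  # marks a node that completes a known name

def make_trie(names):
    trie = {}
    for name in names:
        node = trie
        for word in name.split():
            node = node.setdefault(word, {})
        node[END] = name  # store the full name at its final word
    return trie

def find_names(text, trie):
    words = text.split()
    found = []
    for i in range(len(words)):
        node = trie
        for j in range(i, len(words)):
            if words[j] not in node:
                break
            node = node[words[j]]
            if END in node:
                found.append(node[END])  # record every complete name found
    return found
```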
@@ -195,18 +200,23 @@ def find_datasets(text):
def find_metrics(text):
    return find_names(text, all_metrics_trie)

+def find_tasks(text):
+    return find_names(text, all_tasks_trie)
+
def dummy_item(reason):
    return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))



@njit
-def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, noise, ms_noise, ds_pb, ms_pb, logprobs):
+def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
    empty = typed.Dict.empty(types.unicode_type, types.float64)
    for i, (task, dataset, metric) in enumerate(taxonomy):
        logprob = 0.0
        short_probs = reverse_merged_p.get(dataset, empty)
        met_probs = reverse_metrics_p.get(metric, empty)
+        task_probs = reverse_task_p.get(task, empty)
        for ds in dss:
            # for abbrv, long_form in abbrvs.items():
            #     if ds == abbrv:
@@ -216,17 +226,21 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, no
            logprob += np.log(noise * ds_pb + (1 - noise) * short_probs.get(ds, 0.0))
        for ms in mss:
            logprob += np.log(ms_noise * ms_pb + (1 - ms_noise) * met_probs.get(ms, 0.0))
+        for ts in tss:
+            logprob += np.log(ts_noise * ts_pb + (1 - ts_noise) * task_probs.get(ts, 0.0))
        logprobs[i] += logprob
    #logprobs[(dataset, metric)] = logprob


class ContextSearch:
-    def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, ds_pb=0.001, ms_pb=0.01, debug_gold_df=None):
+    def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, task_noise=None,
+                 ds_pb=0.001, ms_pb=0.01, ts_pb=0.01, debug_gold_df=None):
        merged_p = \
            get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in datasets.items()})[1]
        metrics_p = \
            get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in metrics.items()})[1]
-
+        tasks_p = \
+            get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in tasks.items()})[1]

        self.queries = {}
        self.taxonomy = taxonomy
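
Each mention found in a context contributes `log(noise * pb + (1 - noise) * p(mention | entity))` to a taxonomy entry's score: with probability `noise` the mention is unrelated background (uniform probability `pb`), otherwise it is drawn from that entity's mention distribution. The commit extends this same mixture to tasks via `ts_noise`/`ts_pb`. A toy illustration of why the smoothing matters, with made-up numbers:

```python
import numpy as np

# Toy numbers for illustration only.
noise, ds_pb = 0.5, 0.001
p_mention_given_dataset = 0.3  # mention strongly suggests this dataset
p_mention_given_other = 0.0    # mention never co-occurs with the other one

print(np.log(noise * ds_pb + (1 - noise) * p_mention_given_dataset))  # ~ -1.89
print(np.log(noise * ds_pb + (1 - noise) * p_mention_given_other))    # ~ -7.60
```

The background term keeps the log-probability finite even for entities that never co-occur with the mention, which is what `ds_pb`, `ms_pb`, and the new `ts_pb` are for.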
@@ -236,10 +250,13 @@ def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None,
        self.extract_acronyms = AcronymExtractor()
        self.context_noise = context_noise
        self.metrics_noise = metrics_noise if metrics_noise else context_noise
+        self.task_noise = task_noise if task_noise else context_noise
        self.ds_pb = ds_pb
        self.ms_pb = ms_pb
+        self.ts_pb = ts_pb
        self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
        self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
+        self.reverse_tasks_p = self._numba_update_nested_dict(reverse_probs(tasks_p))
        self.debug_gold_df = debug_gold_df

    def _numba_update_nested_dict(self, nested):
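
`compute_logprobs` is `@njit`-compiled, so it cannot take plain Python dicts; `_numba_update_nested_dict` (its body is outside this diff) evidently converts the dict-of-dicts returned by `reverse_probs` into `numba.typed.Dict` instances, consistent with the `typed.Dict.empty(types.unicode_type, types.float64)` default seen in `compute_logprobs`. A sketch of such a conversion, under that assumption:

```python
from numba import typed, types

# Sketch only: convert a plain {str: {str: float}} into nested numba typed dicts.
def numba_update_nested_dict(nested):
    out = typed.Dict.empty(
        types.unicode_type,
        types.DictType(types.unicode_type, types.float64))
    for key, inner in nested.items():
        d = typed.Dict.empty(types.unicode_type, types.float64)
        for k, v in inner.items():
            d[k] = v
        out[key] = d
    return out
```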
@@ -256,29 +273,33 @@ def _numba_extend_list(self, lst):
            l.append(x)
        return l

-    def compute_context_logprobs(self, context, noise, ms_noise, logprobs):
+    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs):
        context = context or ""
        abbrvs = self.extract_acronyms(context)
        context = normalize_cell_ws(normalize_dataset(context))
        dss = set(find_datasets(context)) | set(abbrvs.keys())
        mss = set(find_metrics(context))
+        tss = set(find_tasks(context))
        dss -= mss
+        dss -= tss
        dss = [normalize_cell(ds) for ds in dss]
        mss = [normalize_cell(ms) for ms in mss]
+        tss = [normalize_cell(ts) for ts in tss]
        ###print("dss", dss)
        ###print("mss", mss)
        dss = self._numba_extend_list(dss)
        mss = self._numba_extend_list(mss)
-        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p,
-                         dss, mss, noise, ms_noise, self.ds_pb, self.ms_pb, logprobs)
+        tss = self._numba_extend_list(tss)
+        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs)

    def match(self, contexts):
        assert len(contexts) == len(self.context_noise)
        n = len(self._taxonomy)
        context_logprobs = np.zeros(n)

-        for context, noise, ms_noise in zip(contexts, self.context_noise, self.metrics_noise):
-            self.compute_context_logprobs(context, noise, ms_noise, context_logprobs)
+        for context, noise, ms_noise, ts_noise in zip(contexts, self.context_noise, self.metrics_noise, self.task_noise):
+            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs)
        keys = self.taxonomy.taxonomy
        logprobs = context_logprobs
        #keys, logprobs = zip(*context_logprobs.items())
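
For orientation, `match` expects one context string per entry in `context_noise` (three by default; the asserted length check enforces this), accumulating log-probabilities over all taxonomy entries; what it returns lies outside this hunk. A hedged usage sketch, where the `taxonomy` object and the three context strings are assumptions:

```python
# Hypothetical usage; `taxonomy` and the context strings are assumed, and the
# three contexts must match the length of context_noise=(0.5, 0.2, 0.1).
cs = ContextSearch(taxonomy)  # task_noise/ts_pb fall back to their defaults
results = cs.match((cell_text, table_caption, paper_text))
```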