Commit 15cfa0b

Add support for tasks to ContextSearch
1 parent 92cc525 commit 15cfa0b

File tree

1 file changed: +29 -8 lines changed

sota_extractor2/models/linking/context_search.py

Lines changed: 29 additions & 8 deletions
@@ -138,13 +138,17 @@
     'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
 })
 
+tasks = {}
+
 # escaped_ws_re = re.compile(r'\\\s+')
 # def name_to_re(name):
 #     return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
 
 #all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
 all_datasets = set(normalize_cell_ws(normalize_dataset(y)) for x in datasets.values() for y in x)
 all_metrics = set(normalize_cell_ws(y) for x in metrics.values() for y in x)
+all_tasks = set(normalize_cell_ws(normalize_dataset(y)) for x in tasks.values() for y in x)
+
 #all_metrics = set(metrics_p.keys())
 
 # all_datasets_re = {x:name_to_re(x) for x in all_datasets}
@@ -187,6 +191,7 @@ def find_names(text, names_trie):
 
 all_datasets_trie = make_trie(all_datasets)
 all_metrics_trie = make_trie(all_metrics)
+all_tasks_trie = make_trie(all_tasks)
 
 
 def find_datasets(text):
@@ -195,18 +200,23 @@ def find_datasets(text):
 def find_metrics(text):
     return find_names(text, all_metrics_trie)
 
+def find_tasks(text):
+    return find_names(text, all_tasks_trie)
+
 def dummy_item(reason):
     return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))
 
 
 
 @njit
-def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, noise, ms_noise, ds_pb, ms_pb, logprobs):
+def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
     empty = typed.Dict.empty(types.unicode_type, types.float64)
     for i, (task, dataset, metric) in enumerate(taxonomy):
         logprob = 0.0
         short_probs = reverse_merged_p.get(dataset, empty)
         met_probs = reverse_metrics_p.get(metric, empty)
+        task_probs = reverse_task_p.get(task, empty)
         for ds in dss:
             # for abbrv, long_form in abbrvs.items():
             #     if ds == abbrv:
@@ -216,17 +226,21 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, no
             logprob += np.log(noise * ds_pb + (1 - noise) * short_probs.get(ds, 0.0))
         for ms in mss:
             logprob += np.log(ms_noise * ms_pb + (1 - ms_noise) * met_probs.get(ms, 0.0))
+        for ts in tss:
+            logprob += np.log(ts_noise * ts_pb + (1 - ts_noise) * task_probs.get(ts, 0.0))
         logprobs[i] += logprob
         #logprobs[(dataset, metric)] = logprob
 
 
 class ContextSearch:
-    def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, ds_pb=0.001, ms_pb=0.01, debug_gold_df=None):
+    def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, task_noise=None,
+                 ds_pb=0.001, ms_pb=0.01, ts_pb=0.01, debug_gold_df=None):
         merged_p = \
             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in datasets.items()})[1]
         metrics_p = \
             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in metrics.items()})[1]
-
+        tasks_p = \
+            get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in tasks.items()})[1]
 
         self.queries = {}
         self.taxonomy = taxonomy
@@ -236,10 +250,13 @@ def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None,
         self.extract_acronyms = AcronymExtractor()
         self.context_noise = context_noise
         self.metrics_noise = metrics_noise if metrics_noise else context_noise
+        self.task_noise = task_noise if task_noise else context_noise
         self.ds_pb = ds_pb
         self.ms_pb = ms_pb
+        self.ts_pb = ts_pb
         self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
         self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
+        self.reverse_tasks_p = self._numba_update_nested_dict(reverse_probs(tasks_p))
         self.debug_gold_df = debug_gold_df
 
     def _numba_update_nested_dict(self, nested):
@@ -256,29 +273,33 @@ def _numba_extend_list(self, lst):
             l.append(x)
         return l
 
-    def compute_context_logprobs(self, context, noise, ms_noise, logprobs):
+    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs):
         context = context or ""
         abbrvs = self.extract_acronyms(context)
         context = normalize_cell_ws(normalize_dataset(context))
         dss = set(find_datasets(context)) | set(abbrvs.keys())
         mss = set(find_metrics(context))
+        tss = set(find_tasks(context))
         dss -= mss
+        dss -= tss
         dss = [normalize_cell(ds) for ds in dss]
         mss = [normalize_cell(ms) for ms in mss]
+        tss = [normalize_cell(ts) for ts in tss]
         ###print("dss", dss)
         ###print("mss", mss)
         dss = self._numba_extend_list(dss)
        	mss = self._numba_extend_list(mss)
-        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p,
-                         dss, mss, noise, ms_noise, self.ds_pb, self.ms_pb, logprobs)
+        tss = self._numba_extend_list(tss)
+        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs)
 
     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
         n = len(self._taxonomy)
        	context_logprobs = np.zeros(n)
 
-        for context, noise, ms_noise in zip(contexts, self.context_noise, self.metrics_noise):
-            self.compute_context_logprobs(context, noise, ms_noise, context_logprobs)
+        for context, noise, ms_noise, ts_noise in zip(contexts, self.context_noise, self.metrics_noise, self.task_noise):
+            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs)
         keys = self.taxonomy.taxonomy
         logprobs = context_logprobs
         #keys, logprobs = zip(*context_logprobs.items())
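The added task term reuses the noise-mixture scoring already applied to datasets and metrics: each task mention `ts` found in a context contributes log(ts_noise * ts_pb + (1 - ts_noise) * p(ts | task)) to every taxonomy entry with that task. A minimal pure-Python sketch of that term, outside numba and with made-up probabilities (`task_evidence_score` and the example table are illustrative, not part of this repo):

```python
import numpy as np

def task_evidence_score(tss, task_probs, ts_noise, ts_pb):
    """Illustrative stand-in for the task loop in compute_logprobs:
    with probability ts_noise a mention is unrelated background noise
    (scored by the flat prior ts_pb); otherwise it is scored against
    the reverse-probability table for this taxonomy task."""
    logprob = 0.0
    for ts in tss:
        logprob += np.log(ts_noise * ts_pb + (1 - ts_noise) * task_probs.get(ts, 0.0))
    return logprob

# Made-up table for one task, and two task mentions found in a context.
probs = {"speech recognition": 0.7, "asr": 0.2}
print(task_evidence_score(["speech recognition", "asr"], probs,
                          ts_noise=0.1, ts_pb=0.01))
```

Unknown mentions fall back to `task_probs.get(ts, 0.0)`, so they are scored purely by the `ts_noise * ts_pb` background term rather than driving the log-probability to minus infinity.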

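Note that `tasks = {}` is still empty in this commit, so `all_tasks` and `find_tasks` are effectively dormant until task names are populated; the wiring, however, is visible from the constructor. A hypothetical usage sketch, assuming a prebuilt `taxonomy` and that the three noise levels line up one-to-one with the contexts passed to `match` (an inference from `assert len(contexts) == len(self.context_noise)`, not documented in this commit):

```python
# Hypothetical; building `taxonomy` is outside the scope of this diff.
cs = ContextSearch(
    taxonomy,
    context_noise=(0.5, 0.2, 0.1),  # one noise level per context
    task_noise=(0.6, 0.3, 0.1),     # task evidence can now be weighted separately
    ts_pb=0.01,                     # background probability of a task mention
)

contexts = (
    "WER dev-other",                              # e.g. a table cell
    "Speech recognition results on LibriSpeech",  # e.g. a caption
    "We evaluate our ASR model on LibriSpeech.",  # e.g. surrounding text
)
cs.match(contexts)
```

When `task_noise` is omitted it defaults to `context_noise`, mirroring how `metrics_noise` already behaved before this change.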