Skip to content

Commit a69176e

Browse files
author
Marcin Kardas
committed
Use njit and tries to speed up context search
1 parent c763bc9 commit a69176e

File tree

1 file changed

+93
-31
lines changed

1 file changed

+93
-31
lines changed

sota_extractor2/models/linking/context_search.py

Lines changed: 93 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import re
99
import pandas as pd
1010
import numpy as np
11+
import ahocorasick
12+
from numba import njit, typed, types
1113

1214
from sota_extractor2.pipeline_logger import pipeline_logger
1315

@@ -136,33 +138,87 @@
136138
'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
137139
})
138140

139-
escaped_ws_re = re.compile(r'\\\s+')
140-
def name_to_re(name):
141-
return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
141+
# escaped_ws_re = re.compile(r'\\\s+')
142+
# def name_to_re(name):
143+
# return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
142144

143145
#all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
144146
# Flatten the name-variant tables into flat sets of known surface forms.
# Idiom fix: set comprehension instead of set() around a generator (C401).
all_datasets = {y for x in datasets.values() for y in x}
all_metrics = {y for x in metrics.values() for y in x}
146148
#all_metrics = set(metrics_p.keys())
147149

148-
all_datasets_re = {x:name_to_re(x) for x in all_datasets}
149-
all_metrics_re = {x:name_to_re(x) for x in all_metrics}
150+
# all_datasets_re = {x:name_to_re(x) for x in all_datasets}
151+
# all_metrics_re = {x:name_to_re(x) for x in all_metrics}
150152
#all_datasets = set(x for v in merged_p.values() for x in v)
151153

152-
def find_names(text, names_re):
153-
return set(name for name, name_re in names_re.items() if name_re.search(text))
154+
# def find_names(text, names_re):
155+
# return set(name for name, name_re in names_re.items() if name_re.search(text))
156+
157+
158+
def make_trie(names):
    """Build an Aho-Corasick automaton over the given names.

    Each name is indexed with its spaces stripped so it can be matched
    against text whose spaces were likewise removed; the stored payload
    carries the normalized length (needed to recover the match start
    offset) together with the original, unstripped name.
    """
    automaton = ahocorasick.Automaton()
    for original in names:
        squeezed = original.replace(" ", "")
        automaton.add_word(squeezed, (len(squeezed), original))
    automaton.make_automaton()
    return automaton
165+
166+
167+
# Word-boundary "profile" markers: b = first letter of a word,
# e = last letter, x = a one-letter word, i = any interior letter.
single_letter_re = re.compile(r"\b\w\b")
init_letter_re = re.compile(r"\b\w")
end_letter_re = re.compile(r"\w\b")
letter_re = re.compile(r"\w")


def find_names(text, names_trie):
    """Return the set of known names occurring in *text* on word boundaries.

    The text is lower-cased and annotated with a character-level profile
    recording word-boundary positions; both the text and the profile then
    have their spaces dropped so indices stay aligned with the
    space-stripped names stored in *names_trie*.  A candidate match is
    kept only when it begins at a word start and finishes at a word end.
    """
    lowered = text.lower()
    profile = letter_re.sub("i", lowered)
    profile = init_letter_re.sub("b", profile)
    profile = end_letter_re.sub("e", profile)
    profile = single_letter_re.sub("x", profile)
    squeezed = lowered.replace(" ", "")
    profile = profile.replace(" ", "")
    found = set()
    for end, (length, name) in names_trie.iter(squeezed):
        begin = end - length + 1
        if profile[end] in ("e", "x") and profile[begin] in ("b", "x"):
            found.add(name)
    return found
186+
187+
188+
# Precompiled Aho-Corasick automatons used by find_datasets / find_metrics.
all_datasets_trie = make_trie(all_datasets)
all_metrics_trie = make_trie(all_metrics)
190+
154191

155192
def find_datasets(text):
    """Return the set of known dataset names mentioned in *text*."""
    return find_names(text, all_datasets_trie)
157194

158195
def find_metrics(text):
    """Return the set of known metric names mentioned in *text*."""
    return find_names(text, all_metrics_trie)
160197

161198
def dummy_item(reason):
    """Build a single-row placeholder result carrying *reason* in the
    dataset/task/metric columns, with empty evidence and zero confidence."""
    return pd.DataFrame({
        "dataset": [reason],
        "task": [reason],
        "metric": [reason],
        "evidence": [""],
        "confidence": [0.0],
    })
163200

164201

165202

203+
@njit
def compute_logprobs(dataset_metric, reverse_merged_p, reverse_metrics_p, dss, mss, noise, logprobs):
    """Accumulate context log-probabilities into *logprobs*, one slot per
    (dataset, metric) pair in *dataset_metric*.

    Every dataset evidence string in *dss* and metric evidence string in
    *mss* contributes a smoothed log term; *noise* interpolates between a
    small uniform floor (0.001 for datasets, 0.01 for metrics) and the
    reversed conditional probability tables.
    """
    # Shared empty typed dict used as the lookup default for unseen keys.
    empty = typed.Dict.empty(types.unicode_type, types.float64)
    for i, (dataset, metric) in enumerate(dataset_metric):
        total = 0.0
        short_probs = reverse_merged_p.get(dataset, empty)
        met_probs = reverse_metrics_p.get(metric, empty)
        for ds in dss:
            total += np.log(noise * 0.001 + (1 - noise) * short_probs.get(ds, 0.0))
        for ms in mss:
            total += np.log(noise * 0.01 + (1 - noise) * met_probs.get(ms, 0.0))
        logprobs[i] += total
221+
166222

167223
class ContextSearch:
168224
def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
@@ -174,47 +230,53 @@ def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
174230

175231
self.queries = {}
176232
self.taxonomy = taxonomy
233+
self._dataset_metric = typed.List()
234+
for t in self.taxonomy.taxonomy:
235+
self._dataset_metric.append(t)
177236
self.extract_acronyms = AcronymExtractor()
178237
self.context_noise = context_noise
179-
self.reverse_merged_p = reverse_probs(merged_p)
180-
self.reverse_metrics_p = reverse_probs(metrics_p)
238+
self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
239+
self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
181240
self.debug_gold_df = debug_gold_df
182241

183-
def compute_logprobs(self, dss, mss, abbrvs, noise, logprobs):
184-
for dataset, metric in self.taxonomy.taxonomy:
185-
logprob = logprobs.get((dataset, metric), 1.0)
186-
short_probs = self.reverse_merged_p.get(dataset, {})
187-
met_probs = self.reverse_metrics_p.get(metric, {})
188-
for ds in dss:
189-
ds = normalize_cell(ds)
190-
# for abbrv, long_form in abbrvs.items():
191-
# if ds == abbrv:
192-
# ds = long_form
193-
# break
194-
# if merged_p[ds].get('NOMATCH', 0.0) < 0.5:
195-
logprob += np.log(noise * 0.001 + (1 - noise) * short_probs.get(ds, 0.0))
196-
for ms in mss:
197-
ms = normalize_cell(ms)
198-
logprob += np.log(noise * 0.01 + (1 - noise) * met_probs.get(ms, 0.0))
199-
logprobs[(dataset, metric)] = logprob
242+
def _numba_update_nested_dict(self, nested):
    """Deep-copy a plain dict-of-dicts into numba typed.Dicts so it can be
    passed into njit-compiled code."""
    outer = typed.Dict()
    for key, inner in nested.items():
        copied = typed.Dict()
        copied.update(inner)
        outer[key] = copied
    return outer
249+
250+
def _numba_extend_list(self, lst):
    """Copy *lst* into a numba typed list of unicode strings for njit use."""
    typed_list = typed.List.empty_list(types.unicode_type)
    for item in lst:
        typed_list.append(item)
    return typed_list
200255

201256
def compute_context_logprobs(self, context, noise, logprobs):
    """Extract dataset/metric evidence from *context* and accumulate its
    log-probability contributions into the *logprobs* array in place.

    Acronyms found in the raw context count as dataset evidence; any
    string recognized as a metric is excluded from the dataset set.
    """
    abbrvs = self.extract_acronyms(context)
    context = normalize_cell_ws(normalize_dataset(context))
    datasets_found = set(find_datasets(context)) | set(abbrvs.keys())
    metrics_found = set(find_metrics(context))
    datasets_found -= metrics_found
    normalized_ds = [normalize_cell(ds) for ds in datasets_found]
    normalized_ms = [normalize_cell(ms) for ms in metrics_found]
    # Convert to numba typed lists before crossing into njit code.
    dss = self._numba_extend_list(normalized_ds)
    mss = self._numba_extend_list(normalized_ms)
    compute_logprobs(self._dataset_metric, self.reverse_merged_p,
                     self.reverse_metrics_p, dss, mss, noise, logprobs)
210269

211270
def match(self, contexts):
    """Score every taxonomy (dataset, metric) pair against *contexts*.

    Each context is weighted by its paired noise level from
    ``self.context_noise``; returns an iterator of (key, probability)
    pairs where the probabilities are a softmax over the accumulated
    log-probabilities.
    """
    assert len(contexts) == len(self.context_noise)
    pair_count = len(self._dataset_metric)
    # NOTE(review): accumulator starts at 1.0 per pair, mirroring the old
    # dict-based logprobs.get(key, 1.0) default — confirm this offset is
    # intentional (it cancels out in the softmax anyway).
    context_logprobs = np.ones(pair_count)
    for context, noise in zip(contexts, self.context_noise):
        self.compute_context_logprobs(context, noise, context_logprobs)
    keys = self.taxonomy.taxonomy.keys()
    probs = softmax(np.array(context_logprobs))
    return zip(keys, probs)
220282

0 commit comments

Comments
 (0)