
Commit b535afd

Author: Marcin Kardas

Cache logprobs

1 parent: c989173

File tree

1 file changed: +29 −5 lines


sota_extractor2/models/linking/context_search.py

Lines changed: 29 additions & 5 deletions
@@ -119,11 +119,17 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb, max_
 @njit
 def compute_logprobs(taxonomy, tasks, datasets, metrics,
                      reverse_merged_p, reverse_metrics_p, reverse_task_p,
-                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb,
                      max_repetitions):
     task_cache = typed.Dict.empty(types.unicode_type, types.float64)
     dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
     metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
+    logprobs = np.zeros(len(taxonomy))
+    axes_logprobs = (
+        np.zeros(len(tasks)),
+        np.zeros(len(datasets)),
+        np.zeros(len(metrics))
+    )
     for i, (task, dataset, metric) in enumerate(taxonomy):
         if dataset not in dataset_cache:
             dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb, 1)
@@ -141,6 +147,7 @@ def compute_logprobs(taxonomy, tasks, datasets, metrics,
 
     for i, metric in enumerate(metrics):
         axes_logprobs[2][i] += metric_cache[metric]
+    return logprobs, axes_logprobs
 
 
 def _to_typed_list(iterable):
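
The two hunks above change compute_logprobs from mutating caller-supplied logprobs/axes_logprobs buffers to allocating and returning fresh arrays. That makes the output a pure function of the inputs, which is what allows the results to be cached later in the commit. A minimal standalone sketch of the pattern, with hypothetical names (not code from this repository):

import numpy as np

def scores_inplace(out):
    # old style: mutate a buffer owned by the caller; the result cannot be
    # cached safely, because every call writes into shared state
    out += 1.0

def scores_pure(n):
    # new style: allocate and return a fresh array; the returned value is
    # determined by the inputs alone, so it can be stored and reused
    out = np.zeros(n)
    out += 1.0
    return out

buf = np.zeros(3)
scores_inplace(buf)   # buf was modified in place
lp = scores_pure(3)   # lp is a fresh array, safe to keep in a cache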
@@ -160,7 +167,9 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.1, 0.2, 0.2,
         tasks_p = \
             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.tasks.items()})[1]
 
+        # todo: use LRU cache to avoid OOM
         self.queries = {}
+        self.logprobs_cache = {}
         self.taxonomy = taxonomy
         self.evidence_finder = evidence_finder
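
The new logprobs_cache is an unbounded dict, and the added todo names the risk: every distinct (evidence, noise) combination stays in memory forever. One way the todo might be resolved (an assumption, not part of this commit) is a small bounded LRU cache; the sketch below supports the same `in`, read and write operations that compute_context_logprobs uses on the plain dict, so it could drop in. functools.lru_cache does not fit directly here because the key is built by hand from Counter contents.

from collections import OrderedDict

class LRUCache:
    # Minimal bounded cache sketch for the "use LRU cache to avoid OOM"
    # todo above; max_size is a hypothetical parameter, not from the repo.
    def __init__(self, max_size=10_000):
        self.max_size = max_size
        self._data = OrderedDict()

    def __contains__(self, key):
        return key in self._data

    def __getitem__(self, key):
        self._data.move_to_end(key)          # mark as most recently used
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:  # evict least recently used
            self._data.popitem(last=False)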

@@ -201,6 +210,11 @@ def _numba_extend_dict(self, dct):
         d.update(dct)
         return d
 
+    def _hash_counter(self, d):
+        items = list(d.items())
+        items = sorted(items)
+        return ";".join([x[0]+":"+str(x[1]) for x in items])
+
     def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
         if isinstance(context, str) or context is None:
             context = context or ""
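
_hash_counter turns a string-to-count mapping into a canonical string: sorting the items first guarantees that two equal mappings serialize to the same key regardless of insertion order. A standalone illustration (the Counter contents here are made up):

from collections import Counter

def hash_counter(d):
    # same logic as _hash_counter above: sort, then join k:v pairs
    items = sorted(d.items())
    return ";".join([x[0] + ":" + str(x[1]) for x in items])

a = Counter(["imagenet", "imagenet", "cifar-10"])
b = Counter(["cifar-10", "imagenet", "imagenet"])
assert hash_counter(a) == hash_counter(b) == "cifar-10:1;imagenet:2"

One design caveat: evidence strings containing ":" or ";" could in principle produce colliding keys; a tuple key such as tuple(sorted(d.items())) would avoid that, at the cost of slightly larger dict keys.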
@@ -224,10 +238,20 @@ def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs,
         dss = self._numba_extend_dict(dss)
         mss = self._numba_extend_dict(mss)
         tss = self._numba_extend_dict(tss)
-        compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
-                         self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
-                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs,
-                         axes_logprobs, self.max_repetitions)
+
+        key = (self._hash_counter(tss), self._hash_counter(dss), self._hash_counter(mss), noise, ms_noise, ts_noise)
+        if key not in self.logprobs_cache:
+            lp, alp = compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
+                                       self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                                       dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb,
+                                       self.max_repetitions)
+            self.logprobs_cache[key] = (lp, alp)
+        else:
+            lp, alp = self.logprobs_cache[key]
+        logprobs += lp
+        axes_logprobs[0] += alp[0]
+        axes_logprobs[1] += alp[1]
+        axes_logprobs[2] += alp[2]
 
     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
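
The final hunk is the memoization itself: the key combines the canonicalized evidence counters with the three noise parameters, the numba-compiled compute_logprobs runs only on a miss, and the (possibly cached) arrays are then accumulated into the caller's buffers, so contexts with identical evidence skip the expensive call entirely. The read-through pattern in isolation, with hypothetical standalone names:

import numpy as np

cache = {}

def memoized(key, compute):
    # run compute() only on a cache miss; hits reuse the stored arrays
    if key not in cache:
        cache[key] = compute()
    return cache[key]

key = ("task:1", "dataset:1", "metric:1", 0.5, 0.1, 0.2)
lp = memoized(key, lambda: np.zeros(4))
total = np.zeros(4)
total += lp   # accumulates the values; the cached array itself is left
              # untouched, so later hits still see the original result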
