Skip to content

Commit 53917e0

Browse files
author
Marcin Kardas
committed
Cache axis logprobs
1 parent 0a55394 commit 53917e0

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

sota_extractor2/models/linking/context_search.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,18 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
117117
@njit
118118
def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
119119
dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
120+
task_cache = typed.Dict.empty(types.unicode_type, types.float64)
121+
dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
122+
metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
120123
for i, (task, dataset, metric) in enumerate(taxonomy):
121-
logprob = 0.0
122-
logprob += axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb)
123-
logprob += axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb)
124-
logprob += axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
125-
logprobs[i] += logprob
124+
if dataset not in dataset_cache:
125+
dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb)
126+
if metric not in metric_cache:
127+
metric_cache[metric] = axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb)
128+
if task not in task_cache:
129+
task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
130+
131+
logprobs[i] = dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
126132

127133

128134
class ContextSearch:

0 commit comments

Comments
 (0)