Skip to content

Commit 0a55394

Browse files
author
Marcin Kardas
committed
Extract axis logprobs computation
1 parent 650050e commit 0a55394

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

sota_extractor2/models/linking/context_search.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -103,23 +103,26 @@ def _init_structs(self, taxonomy):
103103
self.all_metrics_trie = EvidenceFinder.make_trie(self.all_metrics)
104104
self.all_tasks_trie = EvidenceFinder.make_trie(self.all_tasks)
105105

106+
107+
@njit(inline="always")
108+
def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
109+
logprob = 0.0
110+
empty = typed.Dict.empty(types.unicode_type, types.float64)
111+
short_probs = reverse_probs.get(evidences_for, empty)
112+
for evidence in found_evidences:
113+
logprob += np.log(noise * pb + (1 - noise) * short_probs.get(evidence, 0.0))
114+
return logprob
115+
116+
106117
@njit
107118
def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
108119
dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
109-
empty = typed.Dict.empty(types.unicode_type, types.float64)
110120
for i, (task, dataset, metric) in enumerate(taxonomy):
111121
logprob = 0.0
112-
short_probs = reverse_merged_p.get(dataset, empty)
113-
for ds in dss:
114-
logprob += np.log(noise * ds_pb + (1 - noise) * short_probs.get(ds, 0.0))
115-
met_probs = reverse_metrics_p.get(metric, empty)
116-
for ms in mss:
117-
logprob += np.log(ms_noise * ms_pb + (1 - ms_noise) * met_probs.get(ms, 0.0))
118-
task_probs = reverse_task_p.get(task, empty)
119-
for ts in tss:
120-
logprob += np.log(ts_noise * ts_pb + (1 - ts_noise) * task_probs.get(ts, 0.0))
122+
logprob += axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb)
123+
logprob += axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb)
124+
logprob += axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
121125
logprobs[i] += logprob
122-
#logprobs[(dataset, metric)] = logprob
123126

124127

125128
class ContextSearch:

0 commit comments

Comments
 (0)