@@ -103,23 +103,26 @@ def _init_structs(self, taxonomy):
103
103
self .all_metrics_trie = EvidenceFinder .make_trie (self .all_metrics )
104
104
self .all_tasks_trie = EvidenceFinder .make_trie (self .all_tasks )
105
105
106
+
107
+ @njit (inline = "always" )
108
+ def axis_logprobs (evidences_for , reverse_probs , found_evidences , noise , pb ):
109
+ logprob = 0.0
110
+ empty = typed .Dict .empty (types .unicode_type , types .float64 )
111
+ short_probs = reverse_probs .get (evidences_for , empty )
112
+ for evidence in found_evidences :
113
+ logprob += np .log (noise * pb + (1 - noise ) * short_probs .get (evidence , 0.0 ))
114
+ return logprob
115
+
116
+
106
117
@njit
107
118
def compute_logprobs (taxonomy , reverse_merged_p , reverse_metrics_p , reverse_task_p ,
108
119
dss , mss , tss , noise , ms_noise , ts_noise , ds_pb , ms_pb , ts_pb , logprobs ):
109
- empty = typed .Dict .empty (types .unicode_type , types .float64 )
110
120
for i , (task , dataset , metric ) in enumerate (taxonomy ):
111
121
logprob = 0.0
112
- short_probs = reverse_merged_p .get (dataset , empty )
113
- for ds in dss :
114
- logprob += np .log (noise * ds_pb + (1 - noise ) * short_probs .get (ds , 0.0 ))
115
- met_probs = reverse_metrics_p .get (metric , empty )
116
- for ms in mss :
117
- logprob += np .log (ms_noise * ms_pb + (1 - ms_noise ) * met_probs .get (ms , 0.0 ))
118
- task_probs = reverse_task_p .get (task , empty )
119
- for ts in tss :
120
- logprob += np .log (ts_noise * ts_pb + (1 - ts_noise ) * task_probs .get (ts , 0.0 ))
122
+ logprob += axis_logprobs (dataset , reverse_merged_p , dss , noise , ds_pb )
123
+ logprob += axis_logprobs (metric , reverse_metrics_p , mss , ms_noise , ms_pb )
124
+ logprob += axis_logprobs (task , reverse_task_p , tss , ts_noise , ts_pb )
121
125
logprobs [i ] += logprob
122
- #logprobs[(dataset, metric)] = logprob
123
126
124
127
125
128
class ContextSearch :
0 commit comments