@@ -117,12 +117,18 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
117
117
@njit
118
118
def compute_logprobs (taxonomy , reverse_merged_p , reverse_metrics_p , reverse_task_p ,
119
119
dss , mss , tss , noise , ms_noise , ts_noise , ds_pb , ms_pb , ts_pb , logprobs ):
120
+ task_cache = typed .Dict .empty (types .unicode_type , types .float64 )
121
+ dataset_cache = typed .Dict .empty (types .unicode_type , types .float64 )
122
+ metric_cache = typed .Dict .empty (types .unicode_type , types .float64 )
120
123
for i , (task , dataset , metric ) in enumerate (taxonomy ):
121
- logprob = 0.0
122
- logprob += axis_logprobs (dataset , reverse_merged_p , dss , noise , ds_pb )
123
- logprob += axis_logprobs (metric , reverse_metrics_p , mss , ms_noise , ms_pb )
124
- logprob += axis_logprobs (task , reverse_task_p , tss , ts_noise , ts_pb )
125
- logprobs [i ] += logprob
124
+ if dataset not in dataset_cache :
125
+ dataset_cache [dataset ] = axis_logprobs (dataset , reverse_merged_p , dss , noise , ds_pb )
126
+ if metric not in metric_cache :
127
+ metric_cache [metric ] = axis_logprobs (metric , reverse_metrics_p , mss , ms_noise , ms_pb )
128
+ if task not in task_cache :
129
+ task_cache [task ] = axis_logprobs (task , reverse_task_p , tss , ts_noise , ts_pb )
130
+
131
+ logprobs [i ] = dataset_cache [dataset ] + metric_cache [metric ] + task_cache [task ]
126
132
127
133
128
134
class ContextSearch :
0 commit comments