@@ -119,11 +119,17 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb, max_
 @njit
 def compute_logprobs(taxonomy, tasks, datasets, metrics,
                      reverse_merged_p, reverse_metrics_p, reverse_task_p,
-                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb,
                      max_repetitions):
     task_cache = typed.Dict.empty(types.unicode_type, types.float64)
     dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
     metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
+    logprobs = np.zeros(len(taxonomy))
+    axes_logprobs = (
+        np.zeros(len(tasks)),
+        np.zeros(len(datasets)),
+        np.zeros(len(metrics))
+    )
     for i, (task, dataset, metric) in enumerate(taxonomy):
         if dataset not in dataset_cache:
             dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb, 1)
@@ -141,6 +147,7 @@ def compute_logprobs(taxonomy, tasks, datasets, metrics,

     for i, metric in enumerate(metrics):
         axes_logprobs[2][i] += metric_cache[metric]
+    return logprobs, axes_logprobs


 def _to_typed_list(iterable):
@@ -160,7 +167,9 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.1, 0.2, 0.2,
         tasks_p = \
             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.tasks.items()})[1]

+        # todo: use LRU cache to avoid OOM
         self.queries = {}
+        self.logprobs_cache = {}
         self.taxonomy = taxonomy
         self.evidence_finder = evidence_finder
@@ -201,6 +210,11 @@ def _numba_extend_dict(self, dct):
         d.update(dct)
         return d

+    def _hash_counter(self, d):
+        items = list(d.items())
+        items = sorted(items)
+        return ";".join([x[0] + ":" + str(x[1]) for x in items])
+
     def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
         if isinstance(context, str) or context is None:
             context = context or ""
@@ -224,10 +238,20 @@ def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs,
         dss = self._numba_extend_dict(dss)
         mss = self._numba_extend_dict(mss)
         tss = self._numba_extend_dict(tss)
-        compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
-                         self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
-                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs,
-                         axes_logprobs, self.max_repetitions)
+
+        key = (self._hash_counter(tss), self._hash_counter(dss), self._hash_counter(mss), noise, ms_noise, ts_noise)
+        if key not in self.logprobs_cache:
+            lp, alp = compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
+                                       self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                                       dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb,
+                                       self.max_repetitions)
+            self.logprobs_cache[key] = (lp, alp)
+        else:
+            lp, alp = self.logprobs_cache[key]
+        logprobs += lp
+        axes_logprobs[0] += alp[0]
+        axes_logprobs[1] += alp[1]
+        axes_logprobs[2] += alp[2]

     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
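Note on the caching scheme introduced above: `compute_context_logprobs` now memoizes `compute_logprobs` results under a key built from the three evidence counters plus the noise parameters, and `_hash_counter` sorts the items so the key does not depend on insertion order. A tiny self-contained illustration of that property (the evidence names below are made up):

```python
# Illustrative only: two counters with the same contents, filled in a different
# order, yield the same string key, so they hit the same cache entry.
def hash_counter(d):
    return ";".join(k + ":" + str(v) for k, v in sorted(d.items()))

a = {"imagenet": 2, "cifar-10": 1}
b = {"cifar-10": 1, "imagenet": 2}
assert hash_counter(a) == hash_counter(b)
```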
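The new `self.logprobs_cache` dictionary grows without bound across contexts; the `# todo: use LRU cache to avoid OOM` comment points at the intended fix. A minimal sketch of one way to bound it (the class name and size limit are illustrative, not part of this commit), keeping the same `key in cache` / `cache[key]` access pattern used in `compute_context_logprobs`:

```python
from collections import OrderedDict

class BoundedCache:
    """Dict-like cache that evicts the least-recently-used entry past max_size."""
    def __init__(self, max_size=10_000):
        self.max_size = max_size
        self._data = OrderedDict()

    def __contains__(self, key):
        return key in self._data

    def __getitem__(self, key):
        self._data.move_to_end(key)          # mark as most recently used
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)   # evict the oldest entry
```

With this, `self.logprobs_cache = {}` could become `self.logprobs_cache = BoundedCache()` without changing the lookup code.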