@@ -116,8 +116,9 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
 
 # compute log-probabilities in a given context and add them to logprobs
 @njit
-def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
-                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
+def compute_logprobs(taxonomy, tasks, datasets, metrics,
+                     reverse_merged_p, reverse_metrics_p, reverse_task_p,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs):
     task_cache = typed.Dict.empty(types.unicode_type, types.float64)
     dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
     metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
@@ -130,6 +131,21 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task
             task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
 
         logprobs[i] += dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
+    for i, task in enumerate(tasks):
+        axes_logprobs[0][i] += task_cache[task]
+
+    for i, dataset in enumerate(datasets):
+        axes_logprobs[1][i] += dataset_cache[dataset]
+
+    for i, metric in enumerate(metrics):
+        axes_logprobs[2][i] += metric_cache[metric]
+
+
+def _to_typed_list(iterable):
+    l = typed.List()
+    for i in iterable:
+        l.append(i)
+    return l
 
 
 class ContextSearch:
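
The new `_to_typed_list` helper exists because plain (reflected) Python lists are deprecated as arguments to `@njit`-compiled functions such as `compute_logprobs`, so iterables are copied into numba typed lists first. Below is a minimal sketch of that idea; the `total_length` function and the sample task names are illustrative assumptions, not part of this commit.

```python
from numba import njit, typed

def _to_typed_list(iterable):
    # copy an arbitrary iterable into a numba typed.List
    l = typed.List()
    for i in iterable:
        l.append(i)
    return l

@njit
def total_length(strings):
    # runs in nopython mode over a typed.List of unicode strings
    n = 0
    for s in strings:
        n += len(s)
    return n

tasks = _to_typed_list(["image classification", "object detection"])
print(total_length(tasks))   # 36
```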
@@ -145,9 +161,12 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), met
         self.queries = {}
         self.taxonomy = taxonomy
         self.evidence_finder = evidence_finder
-        self._taxonomy = typed.List()
-        for t in self.taxonomy.taxonomy:
-            self._taxonomy.append(t)
+
+        self._taxonomy = _to_typed_list(self.taxonomy.taxonomy)
+        self._taxonomy_tasks = _to_typed_list(self.taxonomy.tasks)
+        self._taxonomy_datasets = _to_typed_list(self.taxonomy.datasets)
+        self._taxonomy_metrics = _to_typed_list(self.taxonomy.metrics)
+
         self.extract_acronyms = AcronymExtractor()
         self.context_noise = context_noise
         self.metrics_noise = metrics_noise if metrics_noise else context_noise
@@ -174,10 +193,10 @@ def _numba_extend_list(self, lst):
             l.append(x)
         return l
 
-    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs):
+    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
         context = context or ""
         abbrvs = self.extract_acronyms(context)
-        context = normalize_cell_ws(normalize_dataset(context))
+        context = normalize_cell_ws(normalize_dataset_ws(context))
         dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
         mss = set(self.evidence_finder.find_metrics(context))
         tss = set(self.evidence_finder.find_tasks(context))
@@ -191,21 +210,34 @@ def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs)
         dss = self._numba_extend_list(dss)
         mss = self._numba_extend_list(mss)
         tss = self._numba_extend_list(tss)
-        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
-                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs)
+        compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
+                         self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs,
+                         axes_logprobs)
 
     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
         n = len(self._taxonomy)
         context_logprobs = np.zeros(n)
+        axes_context_logprobs = _to_typed_list([
+            np.zeros(len(self._taxonomy_tasks)),
+            np.zeros(len(self._taxonomy_datasets)),
+            np.zeros(len(self._taxonomy_metrics)),
+        ])
 
         for context, noise, ms_noise, ts_noise in zip(contexts, self.context_noise, self.metrics_noise, self.task_noise):
-            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs)
+            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs, axes_context_logprobs)
         keys = self.taxonomy.taxonomy
         logprobs = context_logprobs
         #keys, logprobs = zip(*context_logprobs.items())
         probs = softmax(np.array(logprobs))
-        return zip(keys, probs)
+        axes_probs = [softmax(np.array(a)) for a in axes_context_logprobs]
+        return (
+            zip(keys, probs),
+            zip(self.taxonomy.tasks, axes_probs[0]),
+            zip(self.taxonomy.datasets, axes_probs[1]),
+            zip(self.taxonomy.metrics, axes_probs[2])
+        )
 
     def __call__(self, query, datasets, caption, topk=1, debug_info=None):
         cellstr = debug_info.cell.cell_ext_id
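
With this change `match` returns four lazily evaluated `zip` iterators: the joint distribution over taxonomy entries plus one independently normalized distribution per axis (tasks, datasets, metrics). A rough sketch of how a caller might consume it; the `cs` instance and the three context strings are placeholders assumed for illustration, not taken from the diff.

```python
# `cs` is an already-constructed ContextSearch; the tuple matches the
# (datasets, caption, query) contexts used elsewhere in the class.
joint, task_dist, dataset_dist, metric_dist = cs.match(
    ("ImageNet", "Table 2: top-1 accuracy", "ResNet-50"))

best_entry, best_entry_p = max(joint, key=lambda x: x[1])        # most probable taxonomy entry
best_task, best_task_p = max(task_dist, key=lambda x: x[1])      # most probable task on its own axis
best_dataset, best_dataset_p = max(dataset_dist, key=lambda x: x[1])
best_metric, best_metric_p = max(metric_dist, key=lambda x: x[1])
```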
@@ -229,8 +261,10 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
             ###print("Taking result from cache")
             p = self.queries[key]
         else:
-            dist = self.match((datasets, caption, query))
-            top_results = sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)]
+            dists = self.match((datasets, caption, query))
+
+            all_top_results = [sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
+            top_results, top_results_t, top_results_d, top_results_m = all_top_results
 
             entries = []
             for it, prob in top_results:
@@ -239,6 +273,16 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
                 entry.update({"evidence": "", "confidence": prob})
                 entries.append(entry)
 
+            # entries = []
+            # for i in range(5):
+            #     best_independent = dict(
+            #         task=top_results_t[i][0],
+            #         dataset=top_results_d[i][0],
+            #         metric=top_results_m[i][0])
+            #     best_independent.update({"evidence": "", "confidence": top_results_t[i][1]})
+            #     entries.append(best_independent)
+            #entries = [best_independent] + entries
+
             # best, best_p = sorted(dist, key=lambda x: x[1], reverse=True)[0]
             # entry = et[best]
             # p = pd.DataFrame({k:[v] for k, v in entry.items()})
@@ -283,5 +327,5 @@ def from_paper(self, paper):
         return self(text)
 
     def __call__(self, text):
-        text = normalize_cell_ws(normalize_dataset(text))
+        text = normalize_cell_ws(normalize_dataset_ws(text))
         return self.evidence_finder.find_datasets(text) | self.evidence_finder.find_tasks(text)
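
Each per-axis vector of accumulated log-probabilities is turned into a probability distribution with `softmax`, in the same way as the joint vector. A self-contained toy illustration of that normalization step; the numpy-only `softmax` and the sample numbers are assumptions for this sketch, not the project's own utility.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())      # shift by the max for numerical stability
    return e / e.sum()

task_logprobs = np.array([-1.2, -0.3, -2.5])   # toy accumulated log-probabilities for three tasks
task_probs = softmax(task_logprobs)
print(task_probs)                # roughly [0.27, 0.66, 0.07]; sums to 1.0
```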