@@ -248,9 +248,9 @@ def match(self, contexts):
248
248
axes_probs = [softmax (np .array (a )) for a in axes_context_logprobs ]
249
249
return (
250
250
zip (keys , probs ),
251
- zip (self .taxonomy . tasks , axes_probs [0 ]),
252
- zip (self .taxonomy . datasets , axes_probs [1 ]),
253
- zip (self .taxonomy . metrics , axes_probs [2 ])
251
+ zip (self ._taxonomy_tasks , axes_probs [0 ]),
252
+ zip (self ._taxonomy_datasets , axes_probs [1 ]),
253
+ zip (self ._taxonomy_metrics , axes_probs [2 ])
254
254
)
255
255
256
256
def __call__ (self , query , paper_context , abstract_context , table_context , caption , topk = 1 , debug_info = None ):
@@ -259,9 +259,9 @@ def __call__(self, query, paper_context, abstract_context, table_context, captio
259
259
paper_context = paper_context , abstract_context = abstract_context , table_context = table_context ,
260
260
caption = caption )
261
261
262
- paper_hash = ";" .join ("," .join (s .elements ()) for s in paper_context )
263
- abstract_hash = ";" .join ("," .join (s .elements ()) for s in abstract_context )
264
- mentions_hash = ";" .join ("," .join (s .elements ()) for s in table_context )
262
+ paper_hash = ";" .join ("," .join (sorted ( s .elements () )) for s in paper_context )
263
+ abstract_hash = ";" .join ("," .join (sorted ( s .elements () )) for s in abstract_context )
264
+ mentions_hash = ";" .join ("," .join (sorted ( s .elements () )) for s in table_context )
265
265
key = (paper_hash , abstract_hash , mentions_hash , caption , query , topk )
266
266
###print(f"[DEBUG] {cellstr}")
267
267
###print("[DEBUG]", debug_info)
@@ -282,7 +282,7 @@ def __call__(self, query, paper_context, abstract_context, table_context, captio
282
282
else :
283
283
dists = self .match ((paper_context , abstract_context , table_context , caption , query ))
284
284
285
- all_top_results = [sorted (dist , key = lambda x : x [1 ], reverse = True )[:max (topk , 5 )] for dist in dists ]
285
+ all_top_results = [sorted (list ( dist ) , key = lambda x : x [1 ], reverse = True )[:max (topk , 5 )] for dist in dists ]
286
286
top_results , top_results_t , top_results_d , top_results_m = all_top_results
287
287
288
288
entries = []
@@ -358,8 +358,10 @@ def find_references(self, text, references):
358
358
return set (re .findall (refs , text ))
359
359
360
360
def get_table_contexts (self , paper , tables ):
361
- ref_tables = [table for table in tables if table .figure_id ]
361
+ ref_tables = [table for table in tables if table .figure_id and table . figure_id . replace ( "." , "" ) ]
362
362
refs = [table .figure_id .replace ("." , "" ) for table in ref_tables ]
363
+ if not refs :
364
+ return [[Counter (), Counter (), Counter ()] for table in tables ]
363
365
ref_contexts = {ref : [Counter (), Counter (), Counter ()] for ref in refs }
364
366
if hasattr (paper .text , "fragments" ):
365
367
for fragment in paper .text .fragments :
0 commit comments