Skip to content

Commit b0d0dae

Browse files
author
Marcin Kardas
committed
Return the top-k best proposals
1 parent e21a3a0 commit b0d0dae

File tree

3 files changed

+44
-34
lines changed

3 files changed

+44
-34
lines changed

sota_extractor2/models/linking/bm25_naive.py

Lines changed: 28 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ def handle_pm(value):
174174
'confidence', 'parsed', 'struct_model_type', 'struct_dataset']
175175

176176

177-
def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking, datasets):
177+
def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking, datasets, topk=1):
178178
# %%
179179
# Proposal generation
180180
def consume_cells(matrix):
@@ -217,11 +217,6 @@ def annotations(r, c, type='model'):
217217

218218
def linked_proposals(proposals):
219219
for prop in proposals:
220-
df = taxonomy_linking(prop.dataset, datasets, desc, debug_info=prop)
221-
assert len(df) == 1
222-
223-
metric = df['metric'][0]
224-
225220
# heuristic to handle accuracy vs error
226221
first_num = (list(handle_pm(prop.raw_value)) + [0])[0]
227222
format = "{x}"
@@ -234,24 +229,27 @@ def linked_proposals(proposals):
234229
if '%' in prop.raw_value:
235230
format += '%'
236231

237-
# if ("error" in metric or "Error" in metric) and (first_num > 0.5):
238-
if (metric.strip().lower() == "error") and (first_num > 0.5):
239-
metric = "Accuracy"
240-
241-
linked = {
242-
'dataset': df['dataset'][0],
243-
'metric': metric,
244-
'task': df['task'][0],
245-
'format': format,
246-
'raw_value': prop.raw_value,
247-
'model': prop.model_name,
248-
'model_type': prop.model_type,
249-
'cell_ext_id': prop.cell.cell_ext_id,
250-
'confidence': df['confidence'][0],
251-
'struct_model_type': prop.model_type,
252-
'struct_dataset': prop.dataset
253-
}
254-
yield linked
232+
df = taxonomy_linking(prop.dataset, datasets, desc, topk=topk, debug_info=prop)
233+
for _, row in df.iterrows():
234+
metric = row['metric']
235+
# if ("error" in metric or "Error" in metric) and (first_num > 0.5):
236+
if (metric.strip().lower() == "error") and (first_num > 0.5):
237+
metric = "Accuracy"
238+
239+
linked = {
240+
'dataset': row['dataset'],
241+
'metric': metric,
242+
'task': row['task'],
243+
'format': format,
244+
'raw_value': prop.raw_value,
245+
'model': prop.model_name,
246+
'model_type': prop.model_type,
247+
'cell_ext_id': prop.cell.cell_ext_id,
248+
'confidence': row['confidence'],
249+
'struct_model_type': prop.model_type,
250+
'struct_dataset': prop.dataset
251+
}
252+
yield linked
255253

256254
# specify columns in case there's no proposal
257255

@@ -264,7 +262,7 @@ def linked_proposals(proposals):
264262

265263

266264
def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=MatchSearch(),
267-
dataset_extractor=None):
265+
dataset_extractor=None, topk=1):
268266
# dataset_extractor=DatasetExtractor()):
269267
proposals = []
270268
datasets = dataset_extractor.from_paper(paper)
@@ -277,7 +275,11 @@ def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=Mat
277275
table_ext_id = f"{paper_ext_id}/{table.name}"
278276

279277
if 'sota' in tags and 'no_sota_records' not in tags: # only parse tables that are marked as sota
280-
proposals.append(generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking, datasets))
278+
proposals.append(
279+
generate_proposals_for_table(
280+
table_ext_id, matrix, structure, desc, taxonomy_linking, datasets, topk=topk
281+
)
282+
)
281283
if len(proposals):
282284
return pd.concat(proposals)
283285
return pd.DataFrame(columns=proposal_columns)

sota_extractor2/models/linking/context_search.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -308,7 +308,7 @@ def match(self, contexts):
308308
probs = softmax(np.array(logprobs))
309309
return zip(keys, probs)
310310

311-
def __call__(self, query, datasets, caption, debug_info=None):
311+
def __call__(self, query, datasets, caption, topk=1, debug_info=None):
312312
cellstr = debug_info.cell.cell_ext_id
313313
pipeline_logger("linking::taxonomy_linking::call", ext_id=cellstr, query=query, datasets=datasets, caption=caption)
314314
datasets = " ".join(datasets)
@@ -331,10 +331,10 @@ def __call__(self, query, datasets, caption, debug_info=None):
331331
p = self.queries[key]
332332
else:
333333
dist = self.match(key)
334-
topk = sorted(dist, key=lambda x: x[1], reverse=True)[0:5]
334+
top_results = sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)]
335335

336336
entries = []
337-
for it, prob in topk:
337+
for it, prob in top_results:
338338
task, dataset, metric = it
339339
entry = dict(task=task, dataset=dataset, metric=metric)
340340
entry.update({"evidence": "", "confidence": prob})
@@ -363,8 +363,8 @@ def __call__(self, query, datasets, caption, debug_info=None):
363363
else:
364364
print("[EA] No gold sota record found for the cell")
365365
# end of error analysis only
366-
pipeline_logger("linking::taxonomy_linking::topk", ext_id=cellstr, topk=p)
367-
return p.head(1)
366+
pipeline_logger("linking::taxonomy_linking::topk", ext_id=cellstr, topk=p.head(5))
367+
return p.head(topk)
368368

369369

370370
# todo: compare regex approach (old) with find_datasets(.) (current)

sota_extractor2/models/linking/linker.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,10 +10,18 @@ def __init__(self, name, taxonomy_linking, dataset_extractor):
1010
self.dataset_extractor = dataset_extractor
1111
self.__name__ = name
1212

13-
def __call__(self, paper, tables):
13+
def __call__(self, paper, tables, topk=1):
1414
pipeline_logger(f"{Linker.step}::call", paper=paper, tables=tables)
1515
proposals = linked_proposals(paper.paper_id, paper, tables,
1616
taxonomy_linking=self.taxonomy_linking,
17-
dataset_extractor=self.dataset_extractor).set_index('cell_ext_id')
18-
pipeline_logger(f"{Linker.step}::linked", paper=paper, tables=tables, proposals=proposals)
17+
dataset_extractor=self.dataset_extractor,
18+
topk=topk)
19+
20+
if topk == 1:
21+
proposals = proposals.set_index('cell_ext_id')
22+
best = proposals
23+
else:
24+
best = proposals.groupby('cell_ext_id').head(1).set_index('cell_ext_id')
25+
26+
pipeline_logger(f"{Linker.step}::linked", paper=paper, tables=tables, proposals=best)
1927
return proposals

0 commit comments

Comments
 (0)