Skip to content

Commit 5026084

Browse files
committed
Add ensemble model tag
1 parent 2cea810 commit 5026084

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

sota_extractor2/data/paper_collection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=Tr
9393
path = Path(path)
9494
if annotations_path is None:
9595
annotations_path = path / "structure-annotations.json"
96+
else:
97+
annotations_path = Path(annotations_path)
9698
if load_texts:
9799
texts = _load_texts(path, jobs)
98100
else:

sota_extractor2/data/structure.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def consume_cells(table):
2020
for col_id, cell in enumerate(row):
2121
vals = [
2222
remove_text_styles(remove_references(cell.raw_value)),
23-
"",
23+
cell.gold_tags,
2424
cell.refs[0] if cell.refs else "",
2525
cell.layout,
2626
bool(style_tags_re.search(cell.raw_value))
@@ -103,13 +103,13 @@ def fix_reference_hightlight(s):
103103
"cell_layout", "cell_styles", "this_paper", "row", "col", "row_context", "col_context", "ext_id"]
104104

105105

106-
def create_evidence_records(textfrag, cell, paper, table):
106+
def create_evidence_records(textfrag, cell, paper_id, table):
107107
for text_highlited in textfrag.meta['highlight']['text']:
108108
text_highlited = fix_reference_hightlight(fix_refs(text_highlited))
109109
text = highlight_re.sub("", text_highlited)
110110
text_sha1 = hashlib.sha1(text.encode("utf-8")).hexdigest()
111111

112-
cell_ext_id = f"{paper.paper_id}/{table.name}/{cell.row}/{cell.col}"
112+
cell_ext_id = f"{paper_id}/{table.name}/{cell.row}/{cell.col}"
113113

114114
yield {"text_sha1": text_sha1,
115115
"text_highlited": text_highlited,
@@ -120,7 +120,7 @@ def create_evidence_records(textfrag, cell, paper, table):
120120
"cell_reference": cell.vals[2],
121121
"cell_layout": cell.vals[3],
122122
"cell_styles": cell.vals[4],
123-
"this_paper": textfrag.paper_id == paper.paper_id,
123+
"this_paper": textfrag.paper_id == paper_id,
124124
"row": cell.row,
125125
"col": cell.col,
126126
"row_context": " border ".join([str(s) for s in table.matrix.values[cell.row]]),
@@ -137,23 +137,22 @@ def filter_cells(cell_content):
137137
interesting_types = ["model-paper", "model-best", "model-competing", "dataset", "dataset-sub", "dataset-task"]
138138

139139

140-
def evidence_for_table(paper, table, paper_limit, corpus_limit):
140+
def evidence_for_table(paper_id, table, paper_limit, corpus_limit):
141141
records = [
142142
record
143143
for cell in consume_cells(table)
144-
for evidence in fetch_evidence(cell.vals[0], cell.vals[2], paper_id=paper.paper_id, table_name=table.name,
144+
for evidence in fetch_evidence(cell.vals[0], cell.vals[2], paper_id=paper_id, table_name=table.name,
145145
row=cell.row, col=cell.col, paper_limit=paper_limit, corpus_limit=corpus_limit)
146-
for record in create_evidence_records(evidence, cell, paper=paper, table=table)
146+
for record in create_evidence_records(evidence, cell, paper_id=paper_id, table=table)
147147
]
148148
df = pd.DataFrame.from_records(records, columns=evidence_columns)
149149
return df
150150

151151

152-
def prepare_data(paper, tables, csv_path, limit_type='interesting'):
153-
data = [evidence_for_table(paper, table,
152+
def prepare_data(tables, csv_path):
153+
data = [evidence_for_table(table.paper_id, table,
154154
paper_limit=100,
155-
corpus_limit=20,
156-
limit_type=limit_type) for table in progress_bar(tables)]
155+
corpus_limit=20) for table in progress_bar(tables)]
157156
if len(data):
158157
df = pd.concat(data)
159158
else:
@@ -173,7 +172,7 @@ def __init__(self):
173172
setup_default_connection()
174173

175174
def __call__(self, paper, tables, paper_limit=30, corpus_limit=10):
176-
dfs = [evidence_for_table(paper, table, paper_limit, corpus_limit) for table in tables]
175+
dfs = [evidence_for_table(paper.paper_id, table, paper_limit, corpus_limit) for table in tables]
177176
if len(dfs):
178177
return pd.concat(dfs)
179178
return pd.DataFrame(columns=evidence_columns)

sota_extractor2/models/structure/experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Labels(Enum):
2525
"dataset-sub": Labels.DATASET.value,
2626
"model-paper": Labels.PAPER_MODEL.value,
2727
"model-best": Labels.PAPER_MODEL.value,
28+
"model-ensemble": Labels.PAPER_MODEL.value,
2829
"model-competing": Labels.COMPETING_MODEL.value,
2930
"dataset-metric": Labels.METRIC.value,
3031
# "model-params": Labels.PARAMS.value

0 commit comments

Comments
 (0)