Skip to content

Commit 4e9cae8

Browse files
committed
Get matched strings
1 parent 40df5f0 commit 4e9cae8

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

label_tables.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def get_tables(tables_dir):
8787
def normalize_float_value(s):
8888
match = metric_value_re.search(s)
8989
if match:
90-
return whitespace_re.sub("", match.group(1)).replace(",", "")
91-
return '-'
90+
return whitespace_re.sub("", match.group(1)).replace(",", ""), match.group(0).strip()
91+
return '-', None
9292

9393

9494
def test_near(x, precise):
@@ -102,7 +102,7 @@ def test_near(x, precise):
102102

103103

104104
def fuzzy_match(metric, metric_value, target_value):
105-
metric_value = normalize_float_value(str(metric_value))
105+
metric_value, _ = normalize_float_value(str(metric_value))
106106
if metric_value in metric_na:
107107
return False
108108
metric_value = Decimal(metric_value)
@@ -164,7 +164,7 @@ def mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table,
164164
cell_tags = empty_celltags_like(table)
165165
for col in range(cols):
166166
for row in range(rows):
167-
for val in table.iloc[row, col]:
167+
for val, val_str in table.iloc[row, col]:
168168
for record in values:
169169
if comparator(record.normalized, val):
170170
hits += 1
@@ -177,7 +177,8 @@ def mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table,
177177
if arxiv_id == record.arxiv_id:
178178
tags += "<this_paper/>"
179179
tags += f"<comparator>{comp_name}</comparator>" +\
180-
f"<matched_cell>{val}</matched_cell></hit>"
180+
f"<matched_cell>{val}</matched_cell>" +\
181+
f"<matched_str>{val_str}</matched_str></hit>"
181182
cell_tags.iloc[row, col] += tags
182183
return cell_tags, hits
183184

@@ -238,18 +239,16 @@ def match_many(output_dir, task_name, dataset_name, metric_name, tables, values)
238239

239240

240241
def normalize_metric(value):
241-
value = normalize_float_value(str(value))
242+
value, _ = normalize_float_value(str(value))
242243
if value in metric_na:
243244
return Decimal("NaN")
244245
return Decimal(value)
245246

246247

247248
def normalize_cell(cell):
248-
if len(letters_re.findall(cell)) > 2:
249-
return []
250249
matches = metric_value_re.findall(cell)
251250
matches = [normalize_float_value(match[0]) for match in matches]
252-
values = [Decimal(value) for value in matches]
251+
values = [(Decimal(value[0]), value[1]) for value in matches if value not in metric_na]
253252
return values
254253

255254

tables2json.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
def get_celltags(filename):
1313
filename = Path(filename)
1414
if filename.exists():
15-
celltags = pd.read_csv(filename, header=None, dtype=str).fillna('')
15+
16+
try:
17+
celltags = pd.read_csv(filename, header=None, dtype=str).fillna('')
18+
except pd.errors.EmptyDataError:
19+
return pd.DataFrame()
1620
return celltags
1721
else:
1822
return pd.DataFrame()

0 commit comments

Comments
 (0)