Skip to content

Commit 3cdd9e6

Browse files
authored
Merge pull request #33 from SHUzhangshuo/main
fix: table TEDS bug
2 parents a3ab1a2 + 687df1e commit 3cdd9e6

File tree

3 files changed

+161
-18
lines changed

3 files changed

+161
-18
lines changed

tests/test_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def test_table_teds_metric(self):
147147
self.assertTrue(teds_result.success)
148148
self.assertIsInstance(teds_result.score, float)
149149
# 验证固定内容的确定分数
150-
self.assertAlmostEqual(teds_result.score, 0.300000, places=5,
151-
msg=f"table_TEDS分数应该是0.300000,实际: {teds_result.score}")
150+
self.assertAlmostEqual(teds_result.score, 0.5199999999999999, places=5,
151+
msg=f"table_TEDS分数应该是0.5199999999999999,实际: {teds_result.score}")
152152

153153
# 验证详细信息
154154
self.assertEqual(teds_result.details['content_type'], 'table')

tests/test_teds.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,19 @@ def test_very_large_table(self):
208208
self.assertTrue(result.success)
209209
self.assertEqual(result.score, 1.0)
210210

211+
def test_teds_structure_same_content_different(self):
    """Tables with the same structure but different cell text score partially.

    Regression test for the TEDS fix: the predicted and ground-truth tables
    differ only in the text of a single cell, so the score must not drop
    to 0.
    """
    pred = "<table><tr><td>我不喜欢你</td></tr></table>"
    gt = "<table><tr><td>我喜欢你</td></tr></table>"

    result = self.teds_metric.calculate(
        predicted=pred,
        groundtruth=gt,
        table_edit_result=self.valid_table_edit_result
    )
    # The score is a float built from divisions; compare with a tolerance
    # instead of exact equality, and use a unittest assertion so the check
    # is not stripped under ``python -O``.
    self.assertAlmostEqual(
        result.score, 0.8, places=5,
        msg=f"expected TEDS score of about 0.8, got {result.score}"
    )
222+
223+
211224

212225
class TestTEDSAdvanced(unittest.TestCase):
213226
"""Advanced TEDS functionality tests - 高级功能测试"""
@@ -301,6 +314,18 @@ def test_teds_complex_table(self):
301314
self.assertGreater(result.score, 0.0)
302315
self.assertLess(result.score, 1.0)
303316

317+
def test_teds_content_similarity(self):
    """Test TEDS on tables whose cells hold similar but not identical text.

    Both tables have one row with two cells; only the cell text differs.
    """
    table1 = "<table><tr><td>苹果很好吃</td><td>香蕉也不错</td></tr></table>"
    table2 = "<table><tr><td>苹果很美味</td><td>香蕉也很好</td></tr></table>"

    result = self.teds.calculate(
        table1,
        table2,
        table_edit_result=self.valid_table_edit_result
    )
    # The score is a float built from divisions; compare with a tolerance
    # instead of exact equality, and use a unittest assertion so the check
    # is not stripped under ``python -O``.
    self.assertAlmostEqual(
        result.score, 0.4, places=5,
        msg=f"expected TEDS score of about 0.4, got {result.score}"
    )
328+
304329

305330
class TestStructureTEDS(unittest.TestCase):
306331
"""Structure-only TEDS tests - 结构化TEDS测试"""
@@ -457,7 +482,7 @@ def run_all_teds_tests():
457482
# Add all test classes
458483
test_classes = [
459484
# 注意:确保TestTEDSBasic已定义或从其他文件导入
460-
# TestTEDSBasic,
485+
TestTEDSBasic,
461486
TestTEDSAdvanced,
462487
TestStructureTEDS,
463488
TestTEDSEdgeCases

webmainbench/metrics/teds_metrics.py

Lines changed: 133 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
8989

9090
return MetricResult(
9191
metric_name=self.name,
92-
score=max(0.0, min(1.0, teds_score)),
92+
score=max(0.0, min(1.0, teds_score)), # 删除多余的右括号
9393
details=details
9494
)
9595

@@ -288,23 +288,141 @@ def _tree_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
288288

289289
return self._list_edit_distance(children1, children2)
290290
else:
291-
# Nodes are different, calculate minimum cost
292-
# Option 1: Replace tree1 with tree2
293-
cost_replace = 1.0 + self._list_edit_distance(
294-
tree1.get('children', []),
295-
tree2.get('children', [])
296-
)
291+
# 检查结构是否相同(忽略文本内容)
292+
if self._structure_equal(tree1, tree2):
293+
# 结构相同,内容不同,使用内容编辑距离
294+
content_distance = self._content_edit_distance(tree1, tree2)
295+
children_cost = self._list_edit_distance(
296+
tree1.get('children', []),
297+
tree2.get('children', [])
298+
)
299+
return content_distance + children_cost
300+
else:
301+
# 结构不同,使用原有的删除插入策略
302+
# Option 1: Replace tree1 with tree2
303+
cost_replace = 1.0 + self._list_edit_distance(
304+
tree1.get('children', []),
305+
tree2.get('children', [])
306+
)
307+
308+
# Option 2: Delete tree1 and insert tree2
309+
cost_delete_insert = (
310+
float(self._count_nodes(tree1)) +
311+
float(self._count_nodes(tree2))
312+
)
313+
314+
return min(cost_replace, cost_delete_insert)
315+
316+
def _structure_equal(self, tree1: Dict, tree2: Dict) -> bool:
317+
"""Check if two trees have identical structure (same tag, attributes)"""
318+
if tree1['tag'] != tree2['tag']:
319+
return False
320+
321+
# Compare important attributes
322+
attrs1 = tree1.get('attrs', {})
323+
attrs2 = tree2.get('attrs', {})
324+
325+
# Check colspan and rowspan
326+
important_attrs = ['colspan', 'rowspan']
327+
for attr in important_attrs:
328+
if attrs1.get(attr) != attrs2.get(attr):
329+
return False
330+
331+
# 结构相同,忽略文本内容
332+
return True
333+
334+
def _content_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
335+
"""Calculate content edit distance between two trees with same structure"""
336+
if tree1['tag'] != tree2['tag']:
337+
return 1.0 # 标签不同,惩罚1分
338+
339+
# 如果是叶子节点(如td),计算文本内容的编辑距离
340+
if tree1['tag'] == 'td' or not tree1.get('children'):
341+
text1 = tree1.get('text', '')
342+
text2 = tree2.get('text', '')
297343

298-
# Option 2: Delete tree1 and insert tree2
299-
cost_delete_insert = (
300-
float(self._count_nodes(tree1)) +
301-
float(self._count_nodes(tree2))
302-
)
344+
if text1 == text2:
345+
return 0.0 # 内容相同
303346

304-
return min(cost_replace, cost_delete_insert)
305-
347+
# 计算文本编辑距离
348+
return self._text_edit_distance(text1, text2)
349+
350+
# 非叶子节点,递归计算子节点的内容编辑距离
351+
children1 = tree1.get('children', [])
352+
children2 = tree2.get('children', [])
353+
354+
return self._list_content_edit_distance(children1, children2)
355+
356+
def _text_edit_distance(self, text1: str, text2: str) -> float:
357+
"""Calculate normalized edit distance between two text strings"""
358+
if not text1 and not text2:
359+
return 0.0
360+
if not text1 or not text2:
361+
return 1.0
362+
363+
# 计算Levenshtein编辑距离
364+
distance = self._levenshtein_distance(text1, text2)
365+
max_len = max(len(text1), len(text2))
366+
367+
# 返回归一化的编辑距离(0-1之间)
368+
return float(distance) / max_len if max_len > 0 else 0.0
369+
370+
def _levenshtein_distance(self, s1: str, s2: str) -> int:
371+
"""Calculate Levenshtein distance between two strings"""
372+
if len(s1) < len(s2):
373+
return self._levenshtein_distance(s2, s1)
374+
375+
if len(s2) == 0:
376+
return len(s1)
377+
378+
previous_row = list(range(len(s2) + 1))
379+
for i, c1 in enumerate(s1):
380+
current_row = [i + 1]
381+
for j, c2 in enumerate(s2):
382+
insertions = previous_row[j + 1] + 1
383+
deletions = current_row[j] + 1
384+
substitutions = previous_row[j] + (c1 != c2)
385+
current_row.append(min(insertions, deletions, substitutions))
386+
previous_row = current_row
387+
388+
return previous_row[-1]
389+
390+
def _list_content_edit_distance(self, list1: List, list2: List) -> float:
391+
"""Calculate content edit distance between two lists of trees"""
392+
m, n = len(list1), len(list2)
393+
394+
# 初始化DP矩阵
395+
dp = [[0.0] * (n + 1) for _ in range(m + 1)]
396+
397+
# 基础情况
398+
for i in range(1, m + 1):
399+
dp[i][0] = dp[i-1][0] + self._content_edit_distance(list1[i-1], list2[0]) if n > 0 else 1.0
400+
401+
for j in range(1, n + 1):
402+
dp[0][j] = dp[0][j-1] + self._content_edit_distance(list1[0], list2[j-1]) if m > 0 else 1.0
403+
404+
# 填充DP矩阵
405+
for i in range(1, m + 1):
406+
for j in range(1, n + 1):
407+
# 内容替换成本
408+
subst_cost = self._content_edit_distance(list1[i-1], list2[j-1])
409+
410+
# 删除成本
411+
del_cost = 1.0 # 删除一个节点的内容成本
412+
413+
# 插入成本
414+
ins_cost = 1.0 # 插入一个节点的内容成本
415+
416+
dp[i][j] = min(
417+
dp[i-1][j-1] + subst_cost, # 替换
418+
dp[i-1][j] + del_cost, # 删除
419+
dp[i][j-1] + ins_cost # 插入
420+
)
421+
422+
return dp[m][n]
423+
306424
def _list_edit_distance(self, list1: List, list2: List) -> float:
307-
"""Calculate edit distance between two lists of trees."""
425+
"""Calculate edit distance between two lists of trees (for structure comparison)"""
308426
m, n = len(list1), len(list2)
309427

310428
# Initialize DP matrix

0 commit comments

Comments
 (0)