diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 93ee6f9..7e6edf8 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -147,8 +147,8 @@ def test_table_teds_metric(self):
self.assertTrue(teds_result.success)
self.assertIsInstance(teds_result.score, float)
# 验证固定内容的确定分数
- self.assertAlmostEqual(teds_result.score, 0.300000, places=5,
- msg=f"table_TEDS分数应该是0.300000,实际: {teds_result.score}")
+ self.assertAlmostEqual(teds_result.score, 0.5199999999999999, places=5,
+ msg=f"table_TEDS分数应该是0.5199999999999999,实际: {teds_result.score}")
# 验证详细信息
self.assertEqual(teds_result.details['content_type'], 'table')
diff --git a/tests/test_teds.py b/tests/test_teds.py
index cc896e4..94ee3ec 100644
--- a/tests/test_teds.py
+++ b/tests/test_teds.py
@@ -208,6 +208,19 @@ def test_very_large_table(self):
self.assertTrue(result.success)
self.assertEqual(result.score, 1.0)
+    def test_teds_structure_same_content_different(self):
+        """Tables with identical structure but different cell content must not score 0.
+
+        Regression test for the content-aware TEDS distance: the expected
+        score for this fixed pair of tables is 0.8.
+        """
+        # NOTE(review): the table markup inside these two literals appears to have
+        # been lost in this copy of the patch - restore the original HTML tables.
+        pred = "
"
+        gt = ""
+
+        result = self.teds_metric.calculate(
+            predicted=pred,
+            groundtruth=gt,
+            table_edit_result=self.valid_table_edit_result
+        )
+        # The score is a float, so compare with a tolerance instead of exact
+        # equality against 0.7999999999999999; this also matches the
+        # assertAlmostEqual style used by the other metric tests.
+        self.assertAlmostEqual(result.score, 0.8, places=5)
+
+
class TestTEDSAdvanced(unittest.TestCase):
"""Advanced TEDS functionality tests - 高级功能测试"""
@@ -301,6 +314,18 @@ def test_teds_complex_table(self):
self.assertGreater(result.score, 0.0)
self.assertLess(result.score, 1.0)
+    def test_teds_content_similarity(self):
+        """Test TEDS on tables whose structure matches but whose cell text differs.
+
+        The expected score for this fixed pair of tables is 0.4.
+        """
+        # NOTE(review): both literals are empty in this copy of the patch, which
+        # cannot produce the expected score - the HTML table markup was
+        # presumably stripped; restore it before running.
+        table1 = ""
+        table2 = ""
+
+        result = self.teds.calculate(
+            table1,
+            table2,
+            table_edit_result=self.valid_table_edit_result
+        )
+        # Tolerance-based comparison: exact float equality against
+        # 0.3999999999999999 breaks on any harmless change in rounding order.
+        self.assertAlmostEqual(result.score, 0.4, places=5)
+
class TestStructureTEDS(unittest.TestCase):
"""Structure-only TEDS tests - 结构化TEDS测试"""
@@ -457,7 +482,7 @@ def run_all_teds_tests():
# Add all test classes
test_classes = [
# 注意:确保TestTEDSBasic已定义或从其他文件导入
- # TestTEDSBasic,
+ TestTEDSBasic,
TestTEDSAdvanced,
TestStructureTEDS,
TestTEDSEdgeCases
diff --git a/webmainbench/metrics/teds_metrics.py b/webmainbench/metrics/teds_metrics.py
index 9ee0958..04cafb8 100644
--- a/webmainbench/metrics/teds_metrics.py
+++ b/webmainbench/metrics/teds_metrics.py
@@ -89,7 +89,7 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
return MetricResult(
metric_name=self.name,
- score=max(0.0, min(1.0, teds_score)),
+ score=max(0.0, min(1.0, teds_score)), # clamp the TEDS score into [0.0, 1.0]
details=details
)
@@ -288,23 +288,141 @@ def _tree_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
return self._list_edit_distance(children1, children2)
else:
- # Nodes are different, calculate minimum cost
- # Option 1: Replace tree1 with tree2
- cost_replace = 1.0 + self._list_edit_distance(
- tree1.get('children', []),
- tree2.get('children', [])
- )
+ # Check whether the two nodes have the same structure (text content is ignored)
+ if self._structure_equal(tree1, tree2):
+ # Same structure, different content: charge the content edit distance
+ # NOTE(review): for a non-leaf node _content_edit_distance already walks the
+ # children, and children_cost walks the same children again - confirm the
+ # children are not being counted twice.
+ content_distance = self._content_edit_distance(tree1, tree2)
+ children_cost = self._list_edit_distance(
+ tree1.get('children', []),
+ tree2.get('children', [])
+ )
+ return content_distance + children_cost
+ else:
+ # Different structure: fall back to the original replace / delete+insert strategy
+ # Option 1: Replace tree1 with tree2
+ cost_replace = 1.0 + self._list_edit_distance(
+ tree1.get('children', []),
+ tree2.get('children', [])
+ )
+
+ # Option 2: Delete tree1 and insert tree2
+ cost_delete_insert = (
+ float(self._count_nodes(tree1)) +
+ float(self._count_nodes(tree2))
+ )
+
+ return min(cost_replace, cost_delete_insert)
+
+    def _structure_equal(self, tree1: Dict, tree2: Dict) -> bool:
+        """Tell whether two nodes agree on tag, colspan and rowspan.
+
+        Only the two nodes themselves are inspected; their text and their
+        children play no part in the comparison.
+
+        Args:
+            tree1: First node; must contain a 'tag' key, 'attrs' is optional.
+            tree2: Second node, same layout as ``tree1``.
+
+        Returns:
+            True when the tags match and both span attributes are equal
+            (a missing attribute compares as None), otherwise False.
+        """
+        if tree1['tag'] != tree2['tag']:
+            return False
+
+        left_attrs = tree1.get('attrs', {})
+        right_attrs = tree2.get('attrs', {})
+
+        # Spans are the only attributes that affect the table layout here.
+        return all(
+            left_attrs.get(name) == right_attrs.get(name)
+            for name in ('colspan', 'rowspan')
+        )
+
+ def _content_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
+ """Calculate content edit distance between two trees with same structure.
+
+ Returns 1.0 when the tags differ. For a 'td' node, or when tree1 has no
+ children, the 'text' values are compared: 0.0 if equal, otherwise the
+ result of _text_edit_distance. Any other node delegates to
+ _list_content_edit_distance over the two children lists.
+ """
+ if tree1['tag'] != tree2['tag']:
+ return 1.0 # different tags: flat penalty of 1
+
+ # Leaf node (e.g. td): compare the text content.
+ # Note that only tree1 is inspected to decide whether this is a leaf.
+ if tree1['tag'] == 'td' or not tree1.get('children'):
+ text1 = tree1.get('text', '')
+ text2 = tree2.get('text', '')
- # Option 2: Delete tree1 and insert tree2
- cost_delete_insert = (
- float(self._count_nodes(tree1)) +
- float(self._count_nodes(tree2))
- )
+ if text1 == text2:
+ return 0.0 # identical content
- return min(cost_replace, cost_delete_insert)
-
+ # Texts differ: use the text edit distance
+ return self._text_edit_distance(text1, text2)
+
+ # Non-leaf node: recurse into the children and compare their content
+ children1 = tree1.get('children', [])
+ children2 = tree2.get('children', [])
+
+ return self._list_content_edit_distance(children1, children2)
+
+    def _text_edit_distance(self, text1: str, text2: str) -> float:
+        """Return the Levenshtein distance of two texts scaled by the longer length.
+
+        Args:
+            text1: First text.
+            text2: Second text.
+
+        Returns:
+            0.0 when both texts are empty, 1.0 when exactly one of them is
+            empty, otherwise the edit distance divided by the length of the
+            longer text.
+        """
+        if not (text1 or text2):
+            return 0.0
+        if not (text1 and text2):
+            return 1.0
+
+        # Both texts are non-empty here, so the divisor is always positive.
+        longest = max(len(text1), len(text2))
+        raw_distance = self._levenshtein_distance(text1, text2)
+        return float(raw_distance) / longest
+
+    def _levenshtein_distance(self, s1: str, s2: str) -> int:
+        """Return the Levenshtein distance between ``s1`` and ``s2``.
+
+        The distance is the minimum number of single-character insertions,
+        deletions and substitutions needed to turn one string into the other.
+        Only the previous and the current row of the table are kept.
+        """
+        # The longer string drives the outer loop, so every row has
+        # len(shorter) + 1 cells.
+        longer, shorter = (s2, s1) if len(s1) < len(s2) else (s1, s2)
+
+        if not shorter:
+            return len(longer)
+
+        prev = list(range(len(shorter) + 1))
+        for row_no, ch_long in enumerate(longer, start=1):
+            row = [row_no]
+            for col, ch_short in enumerate(shorter):
+                row.append(min(
+                    prev[col + 1] + 1,                  # insertion
+                    row[col] + 1,                       # deletion
+                    prev[col] + (ch_long != ch_short),  # substitution
+                ))
+            prev = row
+
+        return prev[-1]
+
+    def _list_content_edit_distance(self, list1: List, list2: List) -> float:
+        """Calculate the content edit distance between two lists of trees.
+
+        Sequence-alignment dynamic programme over the two lists: aligning
+        list1[i-1] with list2[j-1] costs their _content_edit_distance, while
+        dropping a node from list1 or adding a node from list2 costs 1.0 each.
+
+        Args:
+            list1: Child trees of the first node.
+            list2: Child trees of the second node.
+
+        Returns:
+            The minimum total cost of turning list1 into list2; 0.0 when both
+            lists are empty.
+        """
+        m, n = len(list1), len(list2)
+
+        # Unit cost for removing or adding a whole node.
+        del_cost = 1.0
+        ins_cost = 1.0
+
+        # dp[i][j] = cost of aligning the first i items of list1 with the
+        # first j items of list2.
+        dp = [[0.0] * (n + 1) for _ in range(m + 1)]
+
+        # Base cases: a prefix aligned with nothing costs one deletion (or one
+        # insertion) per node. No node comparison belongs here - comparing
+        # against list2[0] / list1[0] let extra nodes be dropped for free, and
+        # the old conditional expression collapsed the whole column to 1.0
+        # when the other list was empty.
+        for i in range(1, m + 1):
+            dp[i][0] = dp[i - 1][0] + del_cost
+        for j in range(1, n + 1):
+            dp[0][j] = dp[0][j - 1] + ins_cost
+
+        # Fill the DP matrix.
+        for i in range(1, m + 1):
+            for j in range(1, n + 1):
+                subst_cost = self._content_edit_distance(list1[i - 1], list2[j - 1])
+                dp[i][j] = min(
+                    dp[i - 1][j - 1] + subst_cost,  # substitute content
+                    dp[i - 1][j] + del_cost,        # delete from list1
+                    dp[i][j - 1] + ins_cost,        # insert from list2
+                )
+
+        return dp[m][n]
+
def _list_edit_distance(self, list1: List, list2: List) -> float:
- """Calculate edit distance between two lists of trees."""
+ """Calculate edit distance between two lists of trees (for structure comparison)"""
m, n = len(list1), len(list2)
# Initialize DP matrix