diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 93ee6f9..7e6edf8 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -147,8 +147,8 @@ def test_table_teds_metric(self): self.assertTrue(teds_result.success) self.assertIsInstance(teds_result.score, float) # 验证固定内容的确定分数 - self.assertAlmostEqual(teds_result.score, 0.300000, places=5, - msg=f"table_TEDS分数应该是0.300000,实际: {teds_result.score}") + self.assertAlmostEqual(teds_result.score, 0.5199999999999999, places=5, + msg=f"table_TEDS分数应该是0.5199999999999999,实际: {teds_result.score}") # 验证详细信息 self.assertEqual(teds_result.details['content_type'], 'table') diff --git a/tests/test_teds.py b/tests/test_teds.py index cc896e4..94ee3ec 100644 --- a/tests/test_teds.py +++ b/tests/test_teds.py @@ -208,6 +208,19 @@ def test_very_large_table(self): self.assertTrue(result.success) self.assertEqual(result.score, 1.0) + def test_teds_structure_same_content_different(self): + """测试结构相同但内容不同的表格 - 验证修复后的TEDS不会返回0分""" + pred = "
我不喜欢你
" + gt = "
我喜欢你
" + + result = self.teds_metric.calculate( + predicted=pred, + groundtruth=gt, + table_edit_result=self.valid_table_edit_result + ) + assert result.score == 0.7999999999999999 + + class TestTEDSAdvanced(unittest.TestCase): """Advanced TEDS functionality tests - 高级功能测试""" @@ -301,6 +314,18 @@ def test_teds_complex_table(self): self.assertGreater(result.score, 0.0) self.assertLess(result.score, 1.0) + def test_teds_content_similarity(self): + """Test TEDS with similar content but different text - 测试内容相似度""" + table1 = "
苹果很好吃香蕉也不错
" + table2 = "
苹果很美味香蕉也很好
" + + result = self.teds.calculate( + table1, + table2, + table_edit_result=self.valid_table_edit_result + ) + assert result.score == 0.3999999999999999 + class TestStructureTEDS(unittest.TestCase): """Structure-only TEDS tests - 结构化TEDS测试""" @@ -457,7 +482,7 @@ def run_all_teds_tests(): # Add all test classes test_classes = [ # 注意:确保TestTEDSBasic已定义或从其他文件导入 - # TestTEDSBasic, + TestTEDSBasic, TestTEDSAdvanced, TestStructureTEDS, TestTEDSEdgeCases diff --git a/webmainbench/metrics/teds_metrics.py b/webmainbench/metrics/teds_metrics.py index 9ee0958..04cafb8 100644 --- a/webmainbench/metrics/teds_metrics.py +++ b/webmainbench/metrics/teds_metrics.py @@ -89,7 +89,7 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric return MetricResult( metric_name=self.name, - score=max(0.0, min(1.0, teds_score)), + score=max(0.0, min(1.0, teds_score)), # 删除多余的右括号 details=details ) @@ -288,23 +288,141 @@ def _tree_edit_distance(self, tree1: Dict, tree2: Dict) -> float: return self._list_edit_distance(children1, children2) else: - # Nodes are different, calculate minimum cost - # Option 1: Replace tree1 with tree2 - cost_replace = 1.0 + self._list_edit_distance( - tree1.get('children', []), - tree2.get('children', []) - ) + # 检查结构是否相同(忽略文本内容) + if self._structure_equal(tree1, tree2): + # 结构相同,内容不同,使用内容编辑距离 + content_distance = self._content_edit_distance(tree1, tree2) + children_cost = self._list_edit_distance( + tree1.get('children', []), + tree2.get('children', []) + ) + return content_distance + children_cost + else: + # 结构不同,使用原有的删除插入策略 + # Option 1: Replace tree1 with tree2 + cost_replace = 1.0 + self._list_edit_distance( + tree1.get('children', []), + tree2.get('children', []) + ) + + # Option 2: Delete tree1 and insert tree2 + cost_delete_insert = ( + float(self._count_nodes(tree1)) + + float(self._count_nodes(tree2)) + ) + + return min(cost_replace, cost_delete_insert) + + def _structure_equal(self, tree1: Dict, tree2: Dict) -> bool: + """Check if two trees have identical structure (same tag, attributes)""" + if tree1['tag'] != tree2['tag']: + return False + + # Compare important attributes + attrs1 = tree1.get('attrs', {}) + attrs2 = tree2.get('attrs', {}) + + # Check colspan and rowspan + important_attrs = ['colspan', 'rowspan'] + for attr in important_attrs: + if attrs1.get(attr) != attrs2.get(attr): + return False + + # 结构相同,忽略文本内容 + return True + + def _content_edit_distance(self, tree1: Dict, tree2: Dict) -> float: + """Calculate content edit distance between two trees with same structure""" + if tree1['tag'] != tree2['tag']: + return 1.0 # 标签不同,惩罚1分 + + # 如果是叶子节点(如td),计算文本内容的编辑距离 + if tree1['tag'] == 'td' or not tree1.get('children'): + text1 = tree1.get('text', '') + text2 = tree2.get('text', '') - # Option 2: Delete tree1 and insert tree2 - cost_delete_insert = ( - float(self._count_nodes(tree1)) + - float(self._count_nodes(tree2)) - ) + if text1 == text2: + return 0.0 # 内容相同 - return min(cost_replace, cost_delete_insert) - + # 计算文本编辑距离 + return self._text_edit_distance(text1, text2) + + # 非叶子节点,递归计算子节点的内容编辑距离 + children1 = tree1.get('children', []) + children2 = tree2.get('children', []) + + return self._list_content_edit_distance(children1, children2) + + def _text_edit_distance(self, text1: str, text2: str) -> float: + """Calculate normalized edit distance between two text strings""" + if not text1 and not text2: + return 0.0 + if not text1 or not text2: + return 1.0 + + # 计算Levenshtein编辑距离 + distance = self._levenshtein_distance(text1, text2) + max_len = max(len(text1), len(text2)) + + # 返回归一化的编辑距离(0-1之间) + return float(distance) / max_len if max_len > 0 else 0.0 + + def _levenshtein_distance(self, s1: str, s2: str) -> int: + """Calculate Levenshtein distance between two strings""" + if len(s1) < len(s2): + return self._levenshtein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = list(range(len(s2) + 1)) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] + + def _list_content_edit_distance(self, list1: List, list2: List) -> float: + """Calculate content edit distance between two lists of trees""" + m, n = len(list1), len(list2) + + # 初始化DP矩阵 + dp = [[0.0] * (n + 1) for _ in range(m + 1)] + + # 基础情况 + for i in range(1, m + 1): + dp[i][0] = dp[i-1][0] + self._content_edit_distance(list1[i-1], list2[0]) if n > 0 else 1.0 + + for j in range(1, n + 1): + dp[0][j] = dp[0][j-1] + self._content_edit_distance(list1[0], list2[j-1]) if m > 0 else 1.0 + + # 填充DP矩阵 + for i in range(1, m + 1): + for j in range(1, n + 1): + # 内容替换成本 + subst_cost = self._content_edit_distance(list1[i-1], list2[j-1]) + + # 删除成本 + del_cost = 1.0 # 删除一个节点的内容成本 + + # 插入成本 + ins_cost = 1.0 # 插入一个节点的内容成本 + + dp[i][j] = min( + dp[i-1][j-1] + subst_cost, # 替换 + dp[i-1][j] + del_cost, # 删除 + dp[i][j-1] + ins_cost # 插入 + ) + + return dp[m][n] + def _list_edit_distance(self, list1: List, list2: List) -> float: - """Calculate edit distance between two lists of trees.""" + """Calculate edit distance between two lists of trees (for structure comparison)""" m, n = len(list1), len(list2) # Initialize DP matrix