@@ -89,7 +89,7 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
8989
9090 return MetricResult (
9191 metric_name = self .name ,
92- score = max (0.0 , min (1.0 , teds_score )),
92+ score = max (0.0 , min (1.0 , teds_score )), # 删除多余的右括号
9393 details = details
9494 )
9595
@@ -288,23 +288,141 @@ def _tree_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
288288
289289 return self ._list_edit_distance (children1 , children2 )
290290 else :
291- # Nodes are different, calculate minimum cost
292- # Option 1: Replace tree1 with tree2
293- cost_replace = 1.0 + self ._list_edit_distance (
294- tree1 .get ('children' , []),
295- tree2 .get ('children' , [])
296- )
291+ # 检查结构是否相同(忽略文本内容)
292+ if self ._structure_equal (tree1 , tree2 ):
293+ # 结构相同,内容不同,使用内容编辑距离
294+ content_distance = self ._content_edit_distance (tree1 , tree2 )
295+ children_cost = self ._list_edit_distance (
296+ tree1 .get ('children' , []),
297+ tree2 .get ('children' , [])
298+ )
299+ return content_distance + children_cost
300+ else :
301+ # 结构不同,使用原有的删除插入策略
302+ # Option 1: Replace tree1 with tree2
303+ cost_replace = 1.0 + self ._list_edit_distance (
304+ tree1 .get ('children' , []),
305+ tree2 .get ('children' , [])
306+ )
307+
308+ # Option 2: Delete tree1 and insert tree2
309+ cost_delete_insert = (
310+ float (self ._count_nodes (tree1 )) +
311+ float (self ._count_nodes (tree2 ))
312+ )
313+
314+ return min (cost_replace , cost_delete_insert )
315+
316+ def _structure_equal (self , tree1 : Dict , tree2 : Dict ) -> bool :
317+ """Check if two trees have identical structure (same tag, attributes)"""
318+ if tree1 ['tag' ] != tree2 ['tag' ]:
319+ return False
320+
321+ # Compare important attributes
322+ attrs1 = tree1 .get ('attrs' , {})
323+ attrs2 = tree2 .get ('attrs' , {})
324+
325+ # Check colspan and rowspan
326+ important_attrs = ['colspan' , 'rowspan' ]
327+ for attr in important_attrs :
328+ if attrs1 .get (attr ) != attrs2 .get (attr ):
329+ return False
330+
331+ # 结构相同,忽略文本内容
332+ return True
333+
334+ def _content_edit_distance (self , tree1 : Dict , tree2 : Dict ) -> float :
335+ """Calculate content edit distance between two trees with same structure"""
336+ if tree1 ['tag' ] != tree2 ['tag' ]:
337+ return 1.0 # 标签不同,惩罚1分
338+
339+ # 如果是叶子节点(如td),计算文本内容的编辑距离
340+ if tree1 ['tag' ] == 'td' or not tree1 .get ('children' ):
341+ text1 = tree1 .get ('text' , '' )
342+ text2 = tree2 .get ('text' , '' )
297343
298- # Option 2: Delete tree1 and insert tree2
299- cost_delete_insert = (
300- float (self ._count_nodes (tree1 )) +
301- float (self ._count_nodes (tree2 ))
302- )
344+ if text1 == text2 :
345+ return 0.0 # 内容相同
303346
304- return min (cost_replace , cost_delete_insert )
305-
347+ # 计算文本编辑距离
348+ return self ._text_edit_distance (text1 , text2 )
349+
350+ # 非叶子节点,递归计算子节点的内容编辑距离
351+ children1 = tree1 .get ('children' , [])
352+ children2 = tree2 .get ('children' , [])
353+
354+ return self ._list_content_edit_distance (children1 , children2 )
355+
356+ def _text_edit_distance (self , text1 : str , text2 : str ) -> float :
357+ """Calculate normalized edit distance between two text strings"""
358+ if not text1 and not text2 :
359+ return 0.0
360+ if not text1 or not text2 :
361+ return 1.0
362+
363+ # 计算Levenshtein编辑距离
364+ distance = self ._levenshtein_distance (text1 , text2 )
365+ max_len = max (len (text1 ), len (text2 ))
366+
367+ # 返回归一化的编辑距离(0-1之间)
368+ return float (distance ) / max_len if max_len > 0 else 0.0
369+
370+ def _levenshtein_distance (self , s1 : str , s2 : str ) -> int :
371+ """Calculate Levenshtein distance between two strings"""
372+ if len (s1 ) < len (s2 ):
373+ return self ._levenshtein_distance (s2 , s1 )
374+
375+ if len (s2 ) == 0 :
376+ return len (s1 )
377+
378+ previous_row = list (range (len (s2 ) + 1 ))
379+ for i , c1 in enumerate (s1 ):
380+ current_row = [i + 1 ]
381+ for j , c2 in enumerate (s2 ):
382+ insertions = previous_row [j + 1 ] + 1
383+ deletions = current_row [j ] + 1
384+ substitutions = previous_row [j ] + (c1 != c2 )
385+ current_row .append (min (insertions , deletions , substitutions ))
386+ previous_row = current_row
387+
388+ return previous_row [- 1 ]
389+
390+ def _list_content_edit_distance (self , list1 : List , list2 : List ) -> float :
391+ """Calculate content edit distance between two lists of trees"""
392+ m , n = len (list1 ), len (list2 )
393+
394+ # 初始化DP矩阵
395+ dp = [[0.0 ] * (n + 1 ) for _ in range (m + 1 )]
396+
397+ # 基础情况
398+ for i in range (1 , m + 1 ):
399+ dp [i ][0 ] = dp [i - 1 ][0 ] + self ._content_edit_distance (list1 [i - 1 ], list2 [0 ]) if n > 0 else 1.0
400+
401+ for j in range (1 , n + 1 ):
402+ dp [0 ][j ] = dp [0 ][j - 1 ] + self ._content_edit_distance (list1 [0 ], list2 [j - 1 ]) if m > 0 else 1.0
403+
404+ # 填充DP矩阵
405+ for i in range (1 , m + 1 ):
406+ for j in range (1 , n + 1 ):
407+ # 内容替换成本
408+ subst_cost = self ._content_edit_distance (list1 [i - 1 ], list2 [j - 1 ])
409+
410+ # 删除成本
411+ del_cost = 1.0 # 删除一个节点的内容成本
412+
413+ # 插入成本
414+ ins_cost = 1.0 # 插入一个节点的内容成本
415+
416+ dp [i ][j ] = min (
417+ dp [i - 1 ][j - 1 ] + subst_cost , # 替换
418+ dp [i - 1 ][j ] + del_cost , # 删除
419+ dp [i ][j - 1 ] + ins_cost # 插入
420+ )
421+
422+ return dp [m ][n ]
423+
306424 def _list_edit_distance (self , list1 : List , list2 : List ) -> float :
307- """Calculate edit distance between two lists of trees. """
425+ """Calculate edit distance between two lists of trees (for structure comparison) """
308426 m , n = len (list1 ), len (list2 )
309427
310428 # Initialize DP matrix
0 commit comments