1010
1111from typing import Dict , Any , List , Optional
1212import re
13+ from functools import lru_cache
1314from lxml import etree , html
1415from lxml .html import HtmlElement
1516from bs4 import BeautifulSoup
@@ -87,9 +88,12 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
8788 "algorithm" : "TEDS"
8889 }
8990
91+ # 清理缓存以释放内存
92+ self ._clear_memoization_cache ()
93+
9094 return MetricResult (
9195 metric_name = self .name ,
92- score = max (0.0 , min (1.0 , teds_score )), # 删除多余的右括号
96+ score = max (0.0 , min (1.0 , teds_score )),
9397 details = details
9498 )
9599
@@ -268,50 +272,64 @@ def _count_nodes(self, tree: Dict) -> int:
268272
269273 def _tree_edit_distance (self , tree1 : Dict , tree2 : Dict ) -> float :
270274 """
271- Calculate tree edit distance using dynamic programming.
275+ Calculate tree edit distance using dynamic programming with memoization .
272276
273277 This is a simplified version of the tree edit distance algorithm.
274278 For production use, consider using more sophisticated algorithms.
275279 """
280+ # 初始化记忆化缓存
281+ if not hasattr (self , '_memo_cache' ):
282+ self ._memo_cache = {}
283+
284+ # 创建缓存键
285+ cache_key = (id (tree1 ), id (tree2 ))
286+ if cache_key in self ._memo_cache :
287+ return self ._memo_cache [cache_key ]
288+
289+ # 计算编辑距离
276290 if tree1 is None and tree2 is None :
277- return 0.0
278- if tree1 is None :
279- return float (self ._count_nodes (tree2 ))
280- if tree2 is None :
281- return float (self ._count_nodes (tree1 ))
282-
283- # Check if nodes are the same
284- if self ._nodes_equal (tree1 , tree2 ):
285- # Nodes are equal, calculate cost for children
286- children1 = tree1 .get ('children' , [])
287- children2 = tree2 .get ('children' , [])
288-
289- return self ._list_edit_distance (children1 , children2 )
291+ result = 0.0
292+ elif tree1 is None :
293+ result = float (self ._count_nodes (tree2 ))
294+ elif tree2 is None :
295+ result = float (self ._count_nodes (tree1 ))
290296 else :
291- # 检查结构是否相同(忽略文本内容)
292- if self ._structure_equal (tree1 , tree2 ):
293- # 结构相同,内容不同,使用内容编辑距离
294- content_distance = self ._content_edit_distance (tree1 , tree2 )
295- children_cost = self ._list_edit_distance (
296- tree1 .get ('children' , []),
297- tree2 .get ('children' , [])
298- )
299- return content_distance + children_cost
300- else :
301- # 结构不同,使用原有的删除插入策略
302- # Option 1: Replace tree1 with tree2
303- cost_replace = 1.0 + self ._list_edit_distance (
304- tree1 .get ('children' , []),
305- tree2 .get ('children' , [])
306- )
297+ # Check if nodes are the same
298+ if self ._nodes_equal (tree1 , tree2 ):
299+ # Nodes are equal, calculate cost for children
300+ children1 = tree1 .get ('children' , [])
301+ children2 = tree2 .get ('children' , [])
307302
308- # Option 2: Delete tree1 and insert tree2
309- cost_delete_insert = (
310- float (self ._count_nodes (tree1 )) +
311- float (self ._count_nodes (tree2 ))
312- )
313-
314- return min (cost_replace , cost_delete_insert )
303+ result = self ._list_edit_distance (children1 , children2 )
304+ else :
305+ # 检查结构是否相同(忽略文本内容)
306+ if self ._structure_equal (tree1 , tree2 ):
307+ # 结构相同,内容不同,使用内容编辑距离
308+ content_distance = self ._content_edit_distance (tree1 , tree2 )
309+ children_cost = self ._list_edit_distance (
310+ tree1 .get ('children' , []),
311+ tree2 .get ('children' , [])
312+ )
313+ result = content_distance + children_cost
314+ else :
315+ # 结构不同,使用原有的删除插入策略
316+ # Option 1: Replace tree1 with tree2
317+ cost_replace = 1.0 + self ._list_edit_distance (
318+ tree1 .get ('children' , []),
319+ tree2 .get ('children' , [])
320+ )
321+
322+ # Option 2: Delete tree1 and insert tree2
323+ cost_delete_insert = (
324+ float (self ._count_nodes (tree1 )) +
325+ float (self ._count_nodes (tree2 ))
326+ )
327+
328+ result = min (cost_replace , cost_delete_insert )
329+
330+ # 缓存结果
331+ self ._memo_cache [cache_key ] = result
332+ return result
315333
316334 def _structure_equal (self , tree1 : Dict , tree2 : Dict ) -> bool :
317335 """Check if two trees have identical structure (same tag, attributes)"""
@@ -332,45 +350,77 @@ def _structure_equal(self, tree1: Dict, tree2: Dict) -> bool:
332350 return True
333351
334352 def _content_edit_distance (self , tree1 : Dict , tree2 : Dict ) -> float :
335- """Calculate content edit distance between two trees with same structure"""
336- if tree1 ['tag' ] != tree2 ['tag' ]:
337- return 1.0 # 标签不同,惩罚1分
353+ """Calculate content edit distance between two trees with same structure with memoization"""
354+ # 初始化内容编辑距离缓存
355+ if not hasattr (self , '_content_memo_cache' ):
356+ self ._content_memo_cache = {}
338357
339- # 如果是叶子节点(如td),计算文本内容的编辑距离
340- if tree1 ['tag' ] == 'td' or not tree1 .get ('children' ):
341- text1 = tree1 .get ('text' , '' )
342- text2 = tree2 .get ('text' , '' )
343-
344- if text1 == text2 :
345- return 0.0 # 内容相同
346-
347- # 计算文本编辑距离
348- return self ._text_edit_distance (text1 , text2 )
358+ # 创建缓存键
359+ cache_key = (id (tree1 ), id (tree2 ))
360+ if cache_key in self ._content_memo_cache :
361+ return self ._content_memo_cache [cache_key ]
349362
350- # 非叶子节点,递归计算子节点的内容编辑距离
351- children1 = tree1 .get ('children' , [])
352- children2 = tree2 .get ('children' , [])
363+ if tree1 ['tag' ] != tree2 ['tag' ]:
364+ result = 1.0 # 标签不同,惩罚1分
365+ else :
366+ # 如果是叶子节点(如td),计算文本内容的编辑距离
367+ if tree1 ['tag' ] == 'td' or not tree1 .get ('children' ):
368+ text1 = tree1 .get ('text' , '' )
369+ text2 = tree2 .get ('text' , '' )
370+
371+ if text1 == text2 :
372+ result = 0.0 # 内容相同
373+ else :
374+ # 计算文本编辑距离
375+ result = self ._text_edit_distance (text1 , text2 )
376+ else :
377+ # 非叶子节点,递归计算子节点的内容编辑距离
378+ children1 = tree1 .get ('children' , [])
379+ children2 = tree2 .get ('children' , [])
380+
381+ result = self ._list_content_edit_distance (children1 , children2 )
353382
354- return self ._list_content_edit_distance (children1 , children2 )
383+ # 缓存结果
384+ self ._content_memo_cache [cache_key ] = result
385+ return result
355386
356387 def _text_edit_distance (self , text1 : str , text2 : str ) -> float :
357- """Calculate normalized edit distance between two text strings"""
358- if not text1 and not text2 :
359- return 0.0
360- if not text1 or not text2 :
361- return 1.0
388+ """Calculate normalized edit distance between two text strings with memoization"""
389+ # 初始化文本编辑距离缓存
390+ if not hasattr (self , '_text_memo_cache' ):
391+ self ._text_memo_cache = {}
392+
393+ # 创建缓存键
394+ cache_key = (text1 , text2 )
395+ if cache_key in self ._text_memo_cache :
396+ return self ._text_memo_cache [cache_key ]
362397
363- # 计算Levenshtein编辑距离
364- distance = self ._levenshtein_distance (text1 , text2 )
365- max_len = max (len (text1 ), len (text2 ))
398+ if not text1 and not text2 :
399+ result = 0.0
400+ elif not text1 or not text2 :
401+ result = 1.0
402+ else :
403+ # 计算Levenshtein编辑距离
404+ distance = self ._levenshtein_distance (text1 , text2 )
405+ max_len = max (len (text1 ), len (text2 ))
406+
407+ # 返回归一化的编辑距离(0-1之间)
408+ result = float (distance ) / max_len if max_len > 0 else 0.0
366409
367- # 返回归一化的编辑距离(0-1之间)
368- return float (distance ) / max_len if max_len > 0 else 0.0
410+ # 缓存结果
411+ self ._text_memo_cache [cache_key ] = result
412+ return result
369413
370414 def _levenshtein_distance (self , s1 : str , s2 : str ) -> int :
371- """Calculate Levenshtein distance between two strings"""
415+ """Calculate Levenshtein distance between two strings with memoization"""
416+ # 使用lru_cache装饰器进行缓存
417+ return self ._cached_levenshtein_distance (s1 , s2 )
418+
419+ @lru_cache (maxsize = 10000 )
420+ def _cached_levenshtein_distance (self , s1 : str , s2 : str ) -> int :
421+ """Cached version of Levenshtein distance calculation"""
372422 if len (s1 ) < len (s2 ):
373- return self ._levenshtein_distance (s2 , s1 )
423+ return self ._cached_levenshtein_distance (s2 , s1 )
374424
375425 if len (s2 ) == 0 :
376426 return len (s1 )
@@ -422,7 +472,16 @@ def _list_content_edit_distance(self, list1: List, list2: List) -> float:
422472 return dp [m ][n ]
423473
424474 def _list_edit_distance (self , list1 : List , list2 : List ) -> float :
425- """Calculate edit distance between two lists of trees (for structure comparison)"""
475+ """Calculate edit distance between two lists of trees (for structure comparison) with memoization"""
476+ # 初始化列表编辑距离缓存
477+ if not hasattr (self , '_list_memo_cache' ):
478+ self ._list_memo_cache = {}
479+
480+ # 创建缓存键
481+ cache_key = (id (list1 ), id (list2 ))
482+ if cache_key in self ._list_memo_cache :
483+ return self ._list_memo_cache [cache_key ]
484+
426485 m , n = len (list1 ), len (list2 )
427486
428487 # Initialize DP matrix
@@ -453,7 +512,10 @@ def _list_edit_distance(self, list1: List, list2: List) -> float:
453512 dp [i ][j - 1 ] + ins_cost # Insert
454513 )
455514
456- return dp [m ][n ]
515+ result = dp [m ][n ]
516+ # 缓存结果
517+ self ._list_memo_cache [cache_key ] = result
518+ return result
457519
458520 def _nodes_equal (self , node1 : Dict , node2 : Dict ) -> bool :
459521 """Check if two tree nodes are equal."""
@@ -476,6 +538,19 @@ def _nodes_equal(self, node1: Dict, node2: Dict) -> bool:
476538 return False
477539
478540 return True
541+
542+ def _clear_memoization_cache (self ):
543+ """清理所有记忆化缓存以释放内存"""
544+ if hasattr (self , '_memo_cache' ):
545+ self ._memo_cache .clear ()
546+ if hasattr (self , '_list_memo_cache' ):
547+ self ._list_memo_cache .clear ()
548+ if hasattr (self , '_content_memo_cache' ):
549+ self ._content_memo_cache .clear ()
550+ if hasattr (self , '_text_memo_cache' ):
551+ self ._text_memo_cache .clear ()
552+ # 清理lru_cache
553+ self ._cached_levenshtein_distance .cache_clear ()
479554
480555
481556class StructureTEDSMetric (TEDSMetric ):
0 commit comments