Skip to content

Commit cfbfd02

Browse files
authored
Merge pull request #35 from SHUzhangshuo/main
feat(metrics): implement comprehensive memoization for TEDS algorithm
2 parents d90a1e6 + 05e4b4c commit cfbfd02

File tree

2 files changed

+146
-69
lines changed

2 files changed

+146
-69
lines changed

tests/test_teds.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
TEDS (树编辑距离相似性) 指标的综合测试
55
"""
66

7+
import re
78
import unittest
89
import sys
10+
import time
911
from pathlib import Path
1012

1113
# Add project root to path

webmainbench/metrics/teds_metrics.py

Lines changed: 144 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from typing import Dict, Any, List, Optional
1212
import re
13+
from functools import lru_cache
1314
from lxml import etree, html
1415
from lxml.html import HtmlElement
1516
from bs4 import BeautifulSoup
@@ -87,9 +88,12 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
8788
"algorithm": "TEDS"
8889
}
8990

91+
# 清理缓存以释放内存
92+
self._clear_memoization_cache()
93+
9094
return MetricResult(
9195
metric_name=self.name,
92-
score=max(0.0, min(1.0, teds_score)), # 删除多余的右括号
96+
score=max(0.0, min(1.0, teds_score)),
9397
details=details
9498
)
9599

@@ -268,50 +272,64 @@ def _count_nodes(self, tree: Dict) -> int:
268272

269273
def _tree_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
270274
"""
271-
Calculate tree edit distance using dynamic programming.
275+
Calculate tree edit distance using dynamic programming with memoization.
272276
273277
This is a simplified version of the tree edit distance algorithm.
274278
For production use, consider using more sophisticated algorithms.
275279
"""
280+
# 初始化记忆化缓存
281+
if not hasattr(self, '_memo_cache'):
282+
self._memo_cache = {}
283+
284+
# 创建缓存键
285+
cache_key = (id(tree1), id(tree2))
286+
if cache_key in self._memo_cache:
287+
return self._memo_cache[cache_key]
288+
289+
# 计算编辑距离
276290
if tree1 is None and tree2 is None:
277-
return 0.0
278-
if tree1 is None:
279-
return float(self._count_nodes(tree2))
280-
if tree2 is None:
281-
return float(self._count_nodes(tree1))
282-
283-
# Check if nodes are the same
284-
if self._nodes_equal(tree1, tree2):
285-
# Nodes are equal, calculate cost for children
286-
children1 = tree1.get('children', [])
287-
children2 = tree2.get('children', [])
288-
289-
return self._list_edit_distance(children1, children2)
291+
result = 0.0
292+
elif tree1 is None:
293+
result = float(self._count_nodes(tree2))
294+
elif tree2 is None:
295+
result = float(self._count_nodes(tree1))
290296
else:
291-
# 检查结构是否相同(忽略文本内容)
292-
if self._structure_equal(tree1, tree2):
293-
# 结构相同,内容不同,使用内容编辑距离
294-
content_distance = self._content_edit_distance(tree1, tree2)
295-
children_cost = self._list_edit_distance(
296-
tree1.get('children', []),
297-
tree2.get('children', [])
298-
)
299-
return content_distance + children_cost
300-
else:
301-
# 结构不同,使用原有的删除插入策略
302-
# Option 1: Replace tree1 with tree2
303-
cost_replace = 1.0 + self._list_edit_distance(
304-
tree1.get('children', []),
305-
tree2.get('children', [])
306-
)
297+
# Check if nodes are the same
298+
if self._nodes_equal(tree1, tree2):
299+
# Nodes are equal, calculate cost for children
300+
children1 = tree1.get('children', [])
301+
children2 = tree2.get('children', [])
307302

308-
# Option 2: Delete tree1 and insert tree2
309-
cost_delete_insert = (
310-
float(self._count_nodes(tree1)) +
311-
float(self._count_nodes(tree2))
312-
)
313-
314-
return min(cost_replace, cost_delete_insert)
303+
result = self._list_edit_distance(children1, children2)
304+
else:
305+
# 检查结构是否相同(忽略文本内容)
306+
if self._structure_equal(tree1, tree2):
307+
# 结构相同,内容不同,使用内容编辑距离
308+
content_distance = self._content_edit_distance(tree1, tree2)
309+
children_cost = self._list_edit_distance(
310+
tree1.get('children', []),
311+
tree2.get('children', [])
312+
)
313+
result = content_distance + children_cost
314+
else:
315+
# 结构不同,使用原有的删除插入策略
316+
# Option 1: Replace tree1 with tree2
317+
cost_replace = 1.0 + self._list_edit_distance(
318+
tree1.get('children', []),
319+
tree2.get('children', [])
320+
)
321+
322+
# Option 2: Delete tree1 and insert tree2
323+
cost_delete_insert = (
324+
float(self._count_nodes(tree1)) +
325+
float(self._count_nodes(tree2))
326+
)
327+
328+
result = min(cost_replace, cost_delete_insert)
329+
330+
# 缓存结果
331+
self._memo_cache[cache_key] = result
332+
return result
315333

316334
def _structure_equal(self, tree1: Dict, tree2: Dict) -> bool:
317335
"""Check if two trees have identical structure (same tag, attributes)"""
@@ -332,45 +350,77 @@ def _structure_equal(self, tree1: Dict, tree2: Dict) -> bool:
332350
return True
333351

334352
def _content_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
335-
"""Calculate content edit distance between two trees with same structure"""
336-
if tree1['tag'] != tree2['tag']:
337-
return 1.0 # 标签不同,惩罚1分
353+
"""Calculate content edit distance between two trees with same structure with memoization"""
354+
# 初始化内容编辑距离缓存
355+
if not hasattr(self, '_content_memo_cache'):
356+
self._content_memo_cache = {}
338357

339-
# 如果是叶子节点(如td),计算文本内容的编辑距离
340-
if tree1['tag'] == 'td' or not tree1.get('children'):
341-
text1 = tree1.get('text', '')
342-
text2 = tree2.get('text', '')
343-
344-
if text1 == text2:
345-
return 0.0 # 内容相同
346-
347-
# 计算文本编辑距离
348-
return self._text_edit_distance(text1, text2)
358+
# 创建缓存键
359+
cache_key = (id(tree1), id(tree2))
360+
if cache_key in self._content_memo_cache:
361+
return self._content_memo_cache[cache_key]
349362

350-
# 非叶子节点,递归计算子节点的内容编辑距离
351-
children1 = tree1.get('children', [])
352-
children2 = tree2.get('children', [])
363+
if tree1['tag'] != tree2['tag']:
364+
result = 1.0 # 标签不同,惩罚1分
365+
else:
366+
# 如果是叶子节点(如td),计算文本内容的编辑距离
367+
if tree1['tag'] == 'td' or not tree1.get('children'):
368+
text1 = tree1.get('text', '')
369+
text2 = tree2.get('text', '')
370+
371+
if text1 == text2:
372+
result = 0.0 # 内容相同
373+
else:
374+
# 计算文本编辑距离
375+
result = self._text_edit_distance(text1, text2)
376+
else:
377+
# 非叶子节点,递归计算子节点的内容编辑距离
378+
children1 = tree1.get('children', [])
379+
children2 = tree2.get('children', [])
380+
381+
result = self._list_content_edit_distance(children1, children2)
353382

354-
return self._list_content_edit_distance(children1, children2)
383+
# 缓存结果
384+
self._content_memo_cache[cache_key] = result
385+
return result
355386

356387
def _text_edit_distance(self, text1: str, text2: str) -> float:
357-
"""Calculate normalized edit distance between two text strings"""
358-
if not text1 and not text2:
359-
return 0.0
360-
if not text1 or not text2:
361-
return 1.0
388+
"""Calculate normalized edit distance between two text strings with memoization"""
389+
# 初始化文本编辑距离缓存
390+
if not hasattr(self, '_text_memo_cache'):
391+
self._text_memo_cache = {}
392+
393+
# 创建缓存键
394+
cache_key = (text1, text2)
395+
if cache_key in self._text_memo_cache:
396+
return self._text_memo_cache[cache_key]
362397

363-
# 计算Levenshtein编辑距离
364-
distance = self._levenshtein_distance(text1, text2)
365-
max_len = max(len(text1), len(text2))
398+
if not text1 and not text2:
399+
result = 0.0
400+
elif not text1 or not text2:
401+
result = 1.0
402+
else:
403+
# 计算Levenshtein编辑距离
404+
distance = self._levenshtein_distance(text1, text2)
405+
max_len = max(len(text1), len(text2))
406+
407+
# 返回归一化的编辑距离(0-1之间)
408+
result = float(distance) / max_len if max_len > 0 else 0.0
366409

367-
# 返回归一化的编辑距离(0-1之间)
368-
return float(distance) / max_len if max_len > 0 else 0.0
410+
# 缓存结果
411+
self._text_memo_cache[cache_key] = result
412+
return result
369413

370414
def _levenshtein_distance(self, s1: str, s2: str) -> int:
371-
"""Calculate Levenshtein distance between two strings"""
415+
"""Calculate Levenshtein distance between two strings with memoization"""
416+
# 使用lru_cache装饰器进行缓存
417+
return self._cached_levenshtein_distance(s1, s2)
418+
419+
@lru_cache(maxsize=10000)
420+
def _cached_levenshtein_distance(self, s1: str, s2: str) -> int:
421+
"""Cached version of Levenshtein distance calculation"""
372422
if len(s1) < len(s2):
373-
return self._levenshtein_distance(s2, s1)
423+
return self._cached_levenshtein_distance(s2, s1)
374424

375425
if len(s2) == 0:
376426
return len(s1)
@@ -422,7 +472,16 @@ def _list_content_edit_distance(self, list1: List, list2: List) -> float:
422472
return dp[m][n]
423473

424474
def _list_edit_distance(self, list1: List, list2: List) -> float:
425-
"""Calculate edit distance between two lists of trees (for structure comparison)"""
475+
"""Calculate edit distance between two lists of trees (for structure comparison) with memoization"""
476+
# 初始化列表编辑距离缓存
477+
if not hasattr(self, '_list_memo_cache'):
478+
self._list_memo_cache = {}
479+
480+
# 创建缓存键
481+
cache_key = (id(list1), id(list2))
482+
if cache_key in self._list_memo_cache:
483+
return self._list_memo_cache[cache_key]
484+
426485
m, n = len(list1), len(list2)
427486

428487
# Initialize DP matrix
@@ -453,7 +512,10 @@ def _list_edit_distance(self, list1: List, list2: List) -> float:
453512
dp[i][j-1] + ins_cost # Insert
454513
)
455514

456-
return dp[m][n]
515+
result = dp[m][n]
516+
# 缓存结果
517+
self._list_memo_cache[cache_key] = result
518+
return result
457519

458520
def _nodes_equal(self, node1: Dict, node2: Dict) -> bool:
459521
"""Check if two tree nodes are equal."""
@@ -476,6 +538,19 @@ def _nodes_equal(self, node1: Dict, node2: Dict) -> bool:
476538
return False
477539

478540
return True
541+
542+
def _clear_memoization_cache(self):
543+
"""清理所有记忆化缓存以释放内存"""
544+
if hasattr(self, '_memo_cache'):
545+
self._memo_cache.clear()
546+
if hasattr(self, '_list_memo_cache'):
547+
self._list_memo_cache.clear()
548+
if hasattr(self, '_content_memo_cache'):
549+
self._content_memo_cache.clear()
550+
if hasattr(self, '_text_memo_cache'):
551+
self._text_memo_cache.clear()
552+
# 清理lru_cache
553+
self._cached_levenshtein_distance.cache_clear()
479554

480555

481556
class StructureTEDSMetric(TEDSMetric):

0 commit comments

Comments
 (0)