diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index cb6f11d..e259983 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -35,6 +35,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install -e .
+        pip install -r requirements.txt
         pip install pytest pytest-cov pytest-xdist coverage
     - name: Install optional dependencies (ignore failures)
diff --git a/examples/basic_usage.py b/examples/basic_usage.py
index b75d533..c8f44a8 100755
--- a/examples/basic_usage.py
+++ b/examples/basic_usage.py
@@ -9,14 +9,14 @@
 # Import the WebMainBench modules
 from webmainbench import (
     DataLoader, DataSaver, BenchmarkDataset, DataSample,
-    ExtractorFactory, Evaluator,
+    ExtractorFactory, Evaluator, format_results, setup_logging
 )
 
 
 def create_sample_dataset():
     """Create a sample dataset."""
-    
+
     # Sample data covering multiple content types (code, formulas, tables, etc.)
     samples = [
         {
@@ -51,7 +51,8 @@ def greet(name):
         "groundtruth_content_list": [
             {"type": "heading", "content": "Python编程教程", "level": 1},
             {"type": "paragraph", "content": "这是一个Python基础教程,展示如何定义函数。"},
-            {"type": "code", "content": 'def greet(name):\n    """问候函数"""\n    return f"Hello, {name}!"\n\n# 使用示例\nresult = greet("World")\nprint(result)'},
+            {"type": "code",
+             "content": 'def greet(name):\n    """问候函数"""\n    return f"Hello, {name}!"\n\n# 使用示例\nresult = greet("World")\nprint(result)'},
             {"type": "paragraph", "content": "这个函数可以用来问候任何人。"}
         ],
         "url": "https://python-tutorial.example.com/functions",
@@ -180,9 +181,11 @@ def greet(name):
             {"type": "heading", "content": "数据分析报告", "level": 1},
             {"type": "paragraph", "content": "以下是2024年第一季度的销售数据分析。"},
             {"type": "heading", "content": "数据处理代码", "level": 2},
-            {"type": "code", "content": "import pandas as pd\nimport numpy as np\n\n# 读取数据\ndf = pd.read_csv('sales_q1_2024.csv')\n\n# 计算统计信息\nmonthly_avg = df.groupby('month')['sales'].mean()\nprint(f\"平均销售额: {monthly_avg}\")"},
+            {"type": "code",
+             "content": "import pandas as pd\nimport numpy as np\n\n# 读取数据\ndf = pd.read_csv('sales_q1_2024.csv')\n\n# 计算统计信息\nmonthly_avg = df.groupby('month')['sales'].mean()\nprint(f\"平均销售额: {monthly_avg}\")"},
             {"type": "heading", "content": "销售统计", "level": 2},
-            {"type": "table", "content": "| 月份 | 销售额(万元) | 增长率 |\n|------|-------------|--------|\n| 1月 | 120.5 | +15.2% |\n| 2月 | 135.8 | +12.7% |\n| 3月 | 148.3 | +9.2% |"},
+            {"type": "table",
+             "content": "| 月份 | 销售额(万元) | 增长率 |\n|------|-------------|--------|\n| 1月 | 120.5 | +15.2% |\n| 2月 | 135.8 | +12.7% |\n| 3月 | 148.3 | +9.2% |"},
             {"type": "paragraph", "content": "标准差公式:σ = √(Σ(xi - μ)² / n)"},
             {"type": "paragraph", "content": "总体来看,第一季度销售表现良好,呈现稳定增长趋势。"}
         ],
@@ -208,12 +211,12 @@ def greet(name):
 def quicksort(arr):
     if len(arr) <= 1:
         return arr
-    
+
     pivot = arr[len(arr) // 2]
     left = [x for x in arr if x < pivot]
     middle = [x for x in arr if x == pivot]
     right = [x for x in arr if x > pivot]
-    
+
     return quicksort(left) + middle + quicksort(right)
 
 复杂度对比
 
@@ -235,12 +238,12 @@ def greet(name):
 def quicksort(arr):
     if len(arr) <= 1:
         return arr
-    
+
     pivot = arr[len(arr) // 2]
     left = [x for x in arr if x < pivot]
     middle = [x for x in arr if x == pivot]
     right = [x for x in arr if x > pivot]
-    
+
     return quicksort(left) + middle + quicksort(right)
 ```
@@ -259,9 +262,11 @@ def quicksort(arr):
             {"type": "heading", "content": "算法复杂度分析", "level": 1},
             {"type": "paragraph", "content": "这里介绍常见算法的时间复杂度。"},
             {"type": "heading", "content": "快速排序实现", "level": 2},
-            {"type": "code", "content": "def quicksort(arr):\n    if len(arr) <= 1:\n        return arr\n    \n    pivot = arr[len(arr) // 2]\n    left = [x for x in arr if x < pivot]\n    middle = [x for x in arr if x == pivot]\n    right = [x for x in arr if x > pivot]\n    \n    return quicksort(left) + middle + quicksort(right)"},
+            {"type": "code",
+             "content": "def quicksort(arr):\n    if len(arr) <= 1:\n        return arr\n    \n    pivot = arr[len(arr) // 2]\n    left = [x for x in arr if x < pivot]\n    middle = [x for x in arr if x == pivot]\n    right = [x for x in arr if x > pivot]\n    \n    return quicksort(left) + middle + quicksort(right)"},
             {"type": "heading", "content": "复杂度对比", "level": 2},
-            {"type": "table", "content": "| 算法 | 最好情况 | 平均情况 | 最坏情况 |\n|------|----------|----------|----------|\n| 快速排序 | O(n log n) | O(n log n) | O(n²) |\n| 归并排序 | O(n log n) | O(n log n) | O(n log n) |\n| 冒泡排序 | O(n) | O(n²) | O(n²) |"},
+            {"type": "table",
+             "content": "| 算法 | 最好情况 | 平均情况 | 最坏情况 |\n|------|----------|----------|----------|\n| 快速排序 | O(n log n) | O(n log n) | O(n²) |\n| 归并排序 | O(n log n) | O(n log n) | O(n log n) |\n| 冒泡排序 | O(n) | O(n²) | O(n²) |"},
             {"type": "equation-inline", "content": "T(n) = aT(n/b) + f(n)"},
             {"type": "paragraph", "content": "其中 a ≥ 1, b > 1 是常数,f(n) 是正函数。"}
         ],
@@ -279,67 +284,67 @@ def quicksort(arr):
             "content_type": "computer_science"
         }
     ]
-    
+
     # Build the dataset
     dataset = BenchmarkDataset(name="sample_dataset", description="示例评测数据集")
-    
+
     for sample_data in samples:
         sample = DataSample.from_dict(sample_data)
         dataset.add_sample(sample)
-    
+
     return dataset
 
 
 def demo_basic_mock_evaluation():
     """Demonstrate the basic evaluation workflow."""
-    
+
     print("=== WebMainBench basic usage example ===\n")
-    
+
     # Configure logging
     setup_logging(level="INFO")
-    
+
     # 1. Create or load a dataset
     print("1. Creating the sample dataset...")
     dataset = create_sample_dataset()
     print(f"The dataset contains {len(dataset)} samples")
     print(f"Dataset statistics: {dataset.get_statistics()}\n")
-    
+
     # 2. Save the dataset to a file
     data_dir = Path("data")
     data_dir.mkdir(exist_ok=True)
-    
+
     dataset_path = data_dir / "sample_dataset.jsonl"
     DataSaver.save_jsonl(dataset, dataset_path, include_results=False)
     print(f"Dataset saved to: {dataset_path}\n")
-    
+
     # 3. Reload the dataset
     print("2. Reloading the dataset...")
     loaded_dataset = DataLoader.load_jsonl(dataset_path)
     print(f"The reloaded dataset contains {len(loaded_dataset)} samples\n")
-    
+
     # 4. List the available extractors
     print("3. Available extractors:")
     available_extractors = ExtractorFactory.list_available()
     for extractor_name in available_extractors:
         print(f" - {extractor_name}")
     print()
-    
+
     # 5. Create the evaluator
     print("4. Creating the evaluator...")
     evaluator = Evaluator()
     print(f"Available metrics: {evaluator.metric_calculator.list_available_metrics()}\n")
-    
+
     # 6. Create a mock extractor for the demo
     print("5. Creating the mock extractor...")
-    
+
     from webmainbench.extractors import BaseExtractor, ExtractionResult
-    
+
     class MockExtractor(BaseExtractor):
         """Mock extractor used for the demo."""
-        
+
         def _setup(self):
             pass
-        
+
         def _extract_content(self, html, url=None):
             # Simple mock extraction logic
             if "标题" in html:
@@ -351,19 +356,19 @@ def _extract_content(self, html, url=None):
             else:
                 content = "提取的内容"
                 content_list = [{"type": "paragraph", "content": "提取的内容"}]
-            
+
             return ExtractionResult(
                 content=content,
                 content_list=content_list,
                 success=True,
                 confidence_score=0.85
             )
-    
+
     # Register the mock extractor
     ExtractorFactory.register("mock", MockExtractor)
     mock_extractor = ExtractorFactory.create("mock")
     print("Mock extractor created\n")
-    
+
     # 7. Run the evaluation
     print("6. Running the evaluation...")
     result = evaluator.evaluate(
@@ -371,21 +376,21 @@ def _extract_content(self, html, url=None):
         extractor=mock_extractor,
         max_samples=2  # limit the sample count for the demo
     )
-    
+
     # 8. Display the results
     print("\n7. Evaluation results:")
     print("=" * 50)
     formatted_results = format_results(result.to_dict())
     print(formatted_results)
-    
+
     # 9. Save the results
     results_dir = Path("results")
     results_dir.mkdir(exist_ok=True)
-    
+
     results_path = results_dir / "mock_evaluation_results.json"
     DataSaver.save_evaluation_results(result, results_path)
     print(f"\nResults saved to: {results_path}")
-    
+
     # 10. Generate the report
     report_path = results_dir / "mock_evaluation_report.csv"
     DataSaver.save_summary_report(result, report_path)
@@ -394,18 +399,19 @@ def _extract_content(self, html, url=None):
 
 def demo_extractor_comparison():
     """Demonstrate comparing multiple extractors."""
-    
+
     print("\n=== Multi-extractor comparison demo ===\n")
-    
+
     # Create the dataset
     dataset = create_sample_dataset()
-    
+
     # Create several mock extractors
     from webmainbench.extractors import BaseExtractor, ExtractionResult
-    
+
     class ExtractorA(BaseExtractor):
         def _setup(self):
             pass
+
         def _extract_content(self, html, url=None):
             return ExtractionResult(
                 content="抽取器A的结果",
@@ -413,10 +419,11 @@ def _extract_content(self, html, url=None):
                 success=True,
                 confidence_score=0.9
             )
-    
+
     class ExtractorB(BaseExtractor):
         def _setup(self):
             pass
+
         def _extract_content(self, html, url=None):
             return ExtractionResult(
                 content="抽取器B的结果",
@@ -424,33 +431,33 @@ def _extract_content(self, html, url=None):
                 success=True,
                 confidence_score=0.8
             )
-    
+
     # Register the extractors
     ExtractorFactory.register("extractor_a", ExtractorA)
     ExtractorFactory.register("extractor_b", ExtractorB)
-    
+
     # Run the comparison
     evaluator = Evaluator()
     extractors = ["extractor_a", "extractor_b"]
-    
+
     results = evaluator.compare_extractors(
         dataset=dataset,
         extractors=extractors,
         max_samples=2
     )
-    
+
     # Display the comparison results
     print("Comparison results:")
     print("-" * 40)
     for extractor_name, result in results.items():
         overall_score = result.overall_metrics.get('overall', 0)
         print(f"{extractor_name}: {overall_score:.4f}")
-    
+
     # Save the multi-extractor leaderboard
     all_results = []
     for extractor_name, result in results.items():
         all_results.append(result.to_dict())
-    
+
     results_dir = Path("results")
     results_dir.mkdir(exist_ok=True)
     leaderboard_path = results_dir / "leaderboard.csv"
@@ -460,17 +467,17 @@ def _extract_content(self, html, url=None):
 
 def demo_llm_webkit_evaluation():
     """Demonstrate the 6-metric evaluation of the LLM-WebKit extractor."""
-    
+
     print("=== LLM-WebKit extractor 6-metric evaluation example ===\n")
-    
+
     # Configure logging
     setup_logging(level="INFO")
-    
+
     # 1. Create a test dataset covering several content types
     print("1. Creating a test dataset with multiple content types...")
-    
+
     samples = []
-    
+
     # Sample 1: text and code
     samples.append(DataSample(
         id="text_code_sample",
@@ -502,11 +509,12 @@ def hello_world():
         groundtruth_content_list=[
             {"type": "heading", "content": "Python编程示例", "level": 1},
             {"type": "text", "content": "这是一段关于Python编程的介绍文本。"},
-            {"type": "code", "content": "def hello_world():\n    print(\"Hello, World!\")\n    return True", "language": "python"},
+            {"type": "code", "content": "def hello_world():\n    print(\"Hello, World!\")\n    return True",
+             "language": "python"},
             {"type": "text", "content": "以上代码展示了一个简单的Python函数。"}
         ]
     ))
-    
+
     # Sample 2: a table
     samples.append(DataSample(
         id="table_sample",
@@ -546,10 +554,11 @@ def hello_world():
 | 产品B | 200 | 3000 |""",
         groundtruth_content_list=[
             {"type": "heading", "content": "销售数据统计", "level": 2},
-            {"type": "table", "content": "| 产品 | 销量 | 收入 |\n|------|------|------|\n| 产品A | 100 | 1000 |\n| 产品B | 200 | 3000 |"}
+            {"type": "table",
+             "content": "| 产品 | 销量 | 收入 |\n|------|------|------|\n| 产品A | 100 | 1000 |\n| 产品B | 200 | 3000 |"}
         ]
     ))
-    
+
     # Sample 3: formulas
     samples.append(DataSample(
         id="formula_sample",
@@ -577,74 +586,74 @@ def hello_world():
             {"type": "formula", "content": "\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi}"}
         ]
     ))
-    
+
     # Create the dataset and add the samples
     dataset = BenchmarkDataset(name="llm_webkit_test", description="LLM-WebKit 6项指标测试数据集")
     for sample in samples:
         dataset.add_sample(sample)
-    
+
     print(f"The test dataset contains {len(dataset)} samples")
     print(f"Sample types: text+code, table, formula\n")
-    
+
     # 2. Create the LLM-WebKit extractor
     print("2. Creating the LLM-WebKit extractor...")
-    
+
     # Show all available extractors
     available_extractors = ExtractorFactory.list_available()
     print(f"Available extractors: {available_extractors}")
-    
+
     # Create the LLM-WebKit extractor directly, configuring the model path
     config = {
         "model_path": "/Users/chupei/model/checkpoint-3296"
     }
     extractor = ExtractorFactory.create("llm-webkit", config=config)
     print(f"✅ LLM-WebKit extractor created, model path: {config['model_path']}")
-    
+
     print()
-    
+
     # 3. Create the evaluator and list all available metrics
     print("3. Creating the evaluator...")
     evaluator = Evaluator()
     available_metrics = evaluator.metric_calculator.list_available_metrics()
     print(f"✅ Available metrics ({len(available_metrics)}):")
-    
+
     # Display the six target metrics by category
     target_metrics = ["overall", "text_edit", "code_edit", "table_edit", "table_TEDS", "formula_edit"]
-    
+
     for metric in target_metrics:
         if metric in available_metrics:
             print(f"  ✅ {metric}")
         else:
             print(f"  ❌ {metric} (not registered)")
-    
+
     print()
-    
+
     # 4. Run the evaluation
     print("4. Starting the evaluation...")
     print("=" * 60)
-    
+
     result = evaluator.evaluate(
         dataset=dataset,
         extractor=extractor,
         max_samples=None  # evaluate all samples
     )
-    
+
     # 5. Display the detailed 6-metric results
     print("\n5. 📊 Detailed 6-metric evaluation results:")
     print("=" * 60)
-    
+
     results_dict = result.to_dict()
-    
+
     # Pull the metric results out of overall_metrics
     metrics = results_dict.get('overall_metrics', {})
-    
+
     # Display by metric category
     print(f"\n🏆 Composite metric:")
     if 'overall' in metrics:
         print(f"  overall (composite score): {metrics['overall']:.4f}")
     else:
         print("  overall: not computed")
-    
+
     print(f"\n📝 Text metrics:")
     if 'text_edit' in metrics:
         print(f"  text_edit (text edit distance): {metrics['text_edit']:.4f}")
@@ -654,7 +663,7 @@ def hello_world():
         print(f"  code_edit (code edit distance): {metrics['code_edit']:.4f}")
     else:
         print("  code_edit: not computed")
-    
+
     print(f"\n📊 Table metrics:")
     if 'table_edit' in metrics:
         print(f"  table_edit (table edit distance): {metrics['table_edit']:.4f}")
@@ -664,37 +673,37 @@ def hello_world():
         print(f"  table_TEDS (table structure similarity): {metrics['table_TEDS']:.4f}")
     else:
         print("  table_TEDS: not computed")
-    
+
     print(f"\n🧮 Formula metrics:")
     if 'formula_edit' in metrics:
         print(f"  formula_edit (formula edit distance): {metrics['formula_edit']:.4f}")
     else:
         print("  formula_edit: not computed")
-    
+
     print(f"\n📈 Detailed statistics:")
     print(f"  Total samples: {len(dataset)}")
     success_count = len([s for s in results_dict.get('sample_results', []) if s.get('extraction_success', False)])
     failure_count = len(dataset) - success_count
     print(f"  Successful samples: {success_count}")
     print(f"  Failed samples: {failure_count}")
-    
+
     # 6. Save the results to files
     print("\n" + "=" * 60)
     print("6. Saving the evaluation results...")
-    
+
     results_dir = Path("results")
     results_dir.mkdir(exist_ok=True)
-    
+
     # Save the detailed results
     results_path = results_dir / "llm_webkit_evaluation_results.json"
     DataSaver.save_evaluation_results(result, results_path)  # pass the result object directly
     print(f"✅ Detailed results saved to: {results_path}")
-    
+
     # Generate the CSV report
     report_path = results_dir / "llm_webkit_evaluation_report.csv"
     DataSaver.save_summary_report(result, report_path)  # pass the result object directly
     print(f"✅ CSV report saved to: {report_path}")
-    
+
     print("\n" + "=" * 60)
     print("✅ LLM-WebKit 6-metric evaluation complete!")
@@ -702,84 +711,84 @@ def hello_world():
 def demo_dataset_with_extraction():
     """Demonstrate saving a dataset enriched with extracted content."""
     print("=== Demo: saving a dataset with extracted content ===")
-    
+
     from webmainbench import DataLoader, DataSaver, Evaluator, ExtractorFactory
     from pathlib import Path
-    
+
     # Configure the file paths
     data_dir = Path("data")
     dataset_path = data_dir / "sample_dataset.jsonl"
     # dataset_path = "/Users/chupei/Downloads/WebMainBench_dataset_merge_2549.jsonl"
-    
+
     print(f"📂 Dataset file: {dataset_path}")
-    
+
     # 🔧 Create the llm-webkit extractor (used throughout)
     extractor_config = {"model_path": "/Users/chupei/model/checkpoint-3296"}
     extractor = ExtractorFactory.create("llm-webkit", config=extractor_config)
     print(f"🤖 Using extractor: {extractor.name}")
-    
+
     # Create the evaluator
     evaluator = Evaluator()
-    
+
     # 🔧 Choose the evaluation mode: in-memory vs. batched
     USE_BATCHED_MODE = True  # set to True for batched mode (suited to large datasets)
-    
+
     if USE_BATCHED_MODE:
         print("🔄 Using batched mode (memory-optimized)")
-        
+
         # 🚀 Batched evaluation (suited to large datasets)
         result = evaluator.evaluate_batched(
             jsonl_file_path=dataset_path,
            extractor=extractor,  # pass the extractor object directly
-            batch_size=10, # small batches
-            max_samples=20 # for the demo
+            batch_size=10,   # small batches
+            max_samples=20   # for the demo
         )
         print(f"✅ Batched evaluation complete, overall score: {result.overall_metrics.get('overall', 0):.4f}")
-        
+
         # To save a dataset with extracted content, the original dataset has to be reloaded
         # Note: this brief load is only used for saving; it does not affect the memory-optimized run above
         dataset = DataLoader.load_jsonl(dataset_path, include_results=False)
         dataset.name = result.dataset_name
-    
+
     else:
         print("🔄 Using the traditional in-memory mode")
-        
+
         # Load the dataset from file
         print(f"📂 Loading dataset from file: {dataset_path}")
         dataset = DataLoader.load_jsonl(dataset_path, include_results=False)
         dataset.name = "WebMainBench_with_extraction"
         dataset.description = "演示抽取内容保存的测试数据集"
-        
+
         print(f"📊 Dataset loaded, {len(dataset.samples)} samples")
-        
+
         # Run the evaluation
         result = evaluator.evaluate(dataset, extractor)
-        
+
         print(f"✅ Evaluation complete, overall score: {result.overall_metrics.get('overall', 0):.4f}")
-    
+
     # Save the dataset with the extracted content
     results_dir = Path("results")
     enriched_dataset_path = results_dir / f"{dataset.name}_with_{extractor.name}_extraction.jsonl"
-    
+
     DataSaver.save_dataset_with_extraction(
         results=result,
-        dataset=dataset, 
+        dataset=dataset,
         file_path=enriched_dataset_path,
         extractor_name=extractor.name
     )
-    
+
     print(f"💾 Saved the enriched dataset to: {enriched_dataset_path}")
-    
+
     # Save the evaluation results and the summary report
     evaluation_results_path = results_dir / f"{dataset.name}_{extractor.name}_evaluation_results.json"
     summary_report_path = results_dir / f"{dataset.name}_{extractor.name}_evaluation_report.csv"
-    
+
     DataSaver.save_evaluation_results(result, evaluation_results_path)
     DataSaver.save_summary_report(result, summary_report_path)
-    
+
     print(f"📊 Saved evaluation results to: {evaluation_results_path}")
     print(f"📈 Saved summary report to: {summary_report_path}")
-    
+
     # Show the newly saved fields
     print("\n📋 Newly saved fields include:")
     print(f"  - {extractor.name}_content: the extracted content")
@@ -788,6 +797,7 @@ def demo_dataset_with_extraction():
     print(f"  - {extractor.name}_time: extraction time")
     print(f"  - {extractor.name}_*_score: per-metric scores")
 
+
 if __name__ == "__main__":
     try:
         demo_basic_mock_evaluation()
@@ -795,8 +805,9 @@ def demo_dataset_with_extraction():
         demo_extractor_comparison()
         demo_dataset_with_extraction()  # demo: save a dataset with extracted content
         print("\n✅ Example run complete!")
-    
+
     except Exception as e:
         print(f"\n❌ Error during run: {e}")
         import traceback
-        traceback.print_exc()
\ No newline at end of file
+
+        traceback.print_exc()
\ No newline at end of file
diff --git a/examples/demo.py b/examples/demo.py
index 62339b5..1006116 100644
--- a/examples/demo.py
+++ b/examples/demo.py
@@ -1,14 +1,14 @@
 from webmainbench import DataLoader, Evaluator, ExtractorFactory
 
 # 1. Load the benchmark dataset
-dataset = DataLoader.load_jsonl("data/sample_dataset.jsonl")
+dataset = DataLoader.load_jsonl("../data/sample_dataset.jsonl")
 
 # 2. Create an extractor
-extractor = ExtractorFactory.create("jina-ai")
+extractor = ExtractorFactory.create("llm-webkit")
 
 # 3. Run the evaluation
 evaluator = Evaluator()
 result = evaluator.evaluate(dataset, extractor)
 
 # 4. Inspect the results
-print(f"Overall Score: {result}")
\ No newline at end of file
+print(f"Overall Score: {result}")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..04708da
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+rapidFuzz
diff --git a/results/leaderboard.csv b/results/leaderboard.csv
index 8cc82c6..abdd754 100644
--- a/results/leaderboard.csv
+++ b/results/leaderboard.csv
@@ -1,3 +1,3 @@
-extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
-extractor_a,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
-extractor_b,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
+extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
+extractor_a,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
+extractor_b,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
diff --git a/results/llm_webkit_evaluation_report.csv b/results/llm_webkit_evaluation_report.csv
index 4271205..770ebab 100644
--- a/results/llm_webkit_evaluation_report.csv
+++ b/results/llm_webkit_evaluation_report.csv
@@ -1,2 +1,2 @@
-extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
-llm-webkit,llm_webkit_test,3,1.0,0.8224,0.8293,0.6667,1.0,0.963,0.6531
+extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
+llm-webkit,llm_webkit_test,3,1.0,0.8224,0.8293,0.6667,1.0,0.963,0.6531
diff --git a/results/mock_evaluation_report.csv b/results/mock_evaluation_report.csv
index 8fbc2bd..87c0e11 100644
--- a/results/mock_evaluation_report.csv
+++ b/results/mock_evaluation_report.csv
@@ -1,2 +1,2 @@
-extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
-mock,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
+extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
+mock,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
diff --git a/webmainbench/extractors/__init__.py b/webmainbench/extractors/__init__.py
index a7e56a8..e94c225 100644
--- a/webmainbench/extractors/__init__.py
+++ b/webmainbench/extractors/__init__.py
@@ -9,6 +9,8 @@
 from .llm_webkit_extractor import LlmWebkitExtractor
 from .jina_extractor import JinaExtractor
 
+
+
 __all__ = [
     "BaseExtractor",
     "ExtractionResult",
diff --git a/webmainbench/metrics/text_metrics.py b/webmainbench/metrics/text_metrics.py
index fd4c89c..1ebfb5e 100644
--- a/webmainbench/metrics/text_metrics.py
+++ b/webmainbench/metrics/text_metrics.py
@@ -6,7 +6,7 @@
 import difflib
 import re
 from .base import BaseMetric, MetricResult
-
+from rapidfuzz.distance import Levenshtein
 
 class EditDistanceMetric(BaseMetric):
     """Edit distance (Levenshtein distance) metric."""
@@ -62,23 +62,9 @@ def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> Metric
     def _levenshtein_distance(self, s1: str, s2: str) -> int:
         """Calculate Levenshtein distance between two strings."""
-        if len(s1) < len(s2):
-            return self._levenshtein_distance(s2, s1)
-
-        if len(s2) == 0:
-            return len(s1)
-
-        previous_row = list(range(len(s2) + 1))
-        for i, c1 in enumerate(s1):
-            current_row = [i + 1]
-            for j, c2 in enumerate(s2):
-                insertions = previous_row[j + 1] + 1
-                deletions = current_row[j] + 1
-                substitutions = previous_row[j] + (c1 != c2)
-                current_row.append(min(insertions, deletions, substitutions))
-            previous_row = current_row
-
-        return previous_row[-1]
+
+        return Levenshtein.distance(s1, s2)
+
 
 class BLEUMetric(BaseMetric):
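Note on the `text_metrics.py` hunk above: the hand-rolled dynamic-programming Levenshtein (O(len(s1) x len(s2)) in pure Python) is replaced by `rapidfuzz.distance.Levenshtein`, which the new `requirements.txt` declares and the CI workflow now installs. The C++-backed implementation returns the same distances, typically orders of magnitude faster on long extracted texts. Below is a minimal sketch of the behavior the patch relies on; the normalization into a 0-1 score is illustrative only, not necessarily the exact formula `EditDistanceMetric._calculate_score` applies:

```python
# Sketch: what the new _levenshtein_distance delegates to.
# Assumes `pip install rapidfuzz`; the score normalization is illustrative.
from rapidfuzz.distance import Levenshtein

predicted = "Hello, World"
groundtruth = "Hello, Word!"

# Raw edit distance: minimum number of single-character insertions,
# deletions, and substitutions needed to turn one string into the other.
dist = Levenshtein.distance(predicted, groundtruth)
print(dist)  # 2

# One common normalization to a 0-1 similarity (1.0 = exact match);
# rapidfuzz ships the same normalization built in, so both prints agree.
score = 1.0 - dist / max(len(predicted), len(groundtruth), 1)
print(round(score, 4))                                                      # 0.8333
print(round(Levenshtein.normalized_similarity(predicted, groundtruth), 4))  # 0.8333
```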