diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index cb6f11d..e259983 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -35,6 +35,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .
+ pip install -r requirements.txt
pip install pytest pytest-cov pytest-xdist coverage
- name: Install optional dependencies (ignore failures)
diff --git a/examples/basic_usage.py b/examples/basic_usage.py
index b75d533..c8f44a8 100755
--- a/examples/basic_usage.py
+++ b/examples/basic_usage.py
@@ -9,14 +9,14 @@
# 导入 WebMainBench 模块
from webmainbench import (
DataLoader, DataSaver, BenchmarkDataset, DataSample,
- ExtractorFactory, Evaluator,
+ ExtractorFactory, Evaluator,
format_results, setup_logging
)
def create_sample_dataset():
"""创建示例数据集"""
-
+
# 创建示例数据 - 包含多种内容类型(代码、公式、表格等)
samples = [
{
@@ -51,7 +51,8 @@ def greet(name):
"groundtruth_content_list": [
{"type": "heading", "content": "Python编程教程", "level": 1},
{"type": "paragraph", "content": "这是一个Python基础教程,展示如何定义函数。"},
- {"type": "code", "content": 'def greet(name):\n """问候函数"""\n return f"Hello, {name}!"\n\n# 使用示例\nresult = greet("World")\nprint(result)'},
+ {"type": "code",
+ "content": 'def greet(name):\n """问候函数"""\n return f"Hello, {name}!"\n\n# 使用示例\nresult = greet("World")\nprint(result)'},
{"type": "paragraph", "content": "这个函数可以用来问候任何人。"}
],
"url": "https://python-tutorial.example.com/functions",
@@ -180,9 +181,11 @@ def greet(name):
{"type": "heading", "content": "数据分析报告", "level": 1},
{"type": "paragraph", "content": "以下是2024年第一季度的销售数据分析。"},
{"type": "heading", "content": "数据处理代码", "level": 2},
- {"type": "code", "content": "import pandas as pd\nimport numpy as np\n\n# 读取数据\ndf = pd.read_csv('sales_q1_2024.csv')\n\n# 计算统计信息\nmonthly_avg = df.groupby('month')['sales'].mean()\nprint(f\"平均销售额: {monthly_avg}\")"},
+ {"type": "code",
+ "content": "import pandas as pd\nimport numpy as np\n\n# 读取数据\ndf = pd.read_csv('sales_q1_2024.csv')\n\n# 计算统计信息\nmonthly_avg = df.groupby('month')['sales'].mean()\nprint(f\"平均销售额: {monthly_avg}\")"},
{"type": "heading", "content": "销售统计", "level": 2},
- {"type": "table", "content": "| 月份 | 销售额(万元) | 增长率 |\n|------|-------------|--------|\n| 1月 | 120.5 | +15.2% |\n| 2月 | 135.8 | +12.7% |\n| 3月 | 148.3 | +9.2% |"},
+ {"type": "table",
+ "content": "| 月份 | 销售额(万元) | 增长率 |\n|------|-------------|--------|\n| 1月 | 120.5 | +15.2% |\n| 2月 | 135.8 | +12.7% |\n| 3月 | 148.3 | +9.2% |"},
{"type": "paragraph", "content": "标准差公式:σ = √(Σ(xi - μ)² / n)"},
{"type": "paragraph", "content": "总体来看,第一季度销售表现良好,呈现稳定增长趋势。"}
],
@@ -208,12 +211,12 @@ def greet(name):
def quicksort(arr):
if len(arr) <= 1:
return arr
-
+
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
-
+
return quicksort(left) + middle + quicksort(right)
复杂度对比
@@ -235,12 +238,12 @@ def greet(name):
def quicksort(arr):
if len(arr) <= 1:
return arr
-
+
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
-
+
return quicksort(left) + middle + quicksort(right)
```
@@ -259,9 +262,11 @@ def quicksort(arr):
{"type": "heading", "content": "算法复杂度分析", "level": 1},
{"type": "paragraph", "content": "这里介绍常见算法的时间复杂度。"},
{"type": "heading", "content": "快速排序实现", "level": 2},
- {"type": "code", "content": "def quicksort(arr):\n if len(arr) <= 1:\n return arr\n \n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n \n return quicksort(left) + middle + quicksort(right)"},
+ {"type": "code",
+ "content": "def quicksort(arr):\n if len(arr) <= 1:\n return arr\n \n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n \n return quicksort(left) + middle + quicksort(right)"},
{"type": "heading", "content": "复杂度对比", "level": 2},
- {"type": "table", "content": "| 算法 | 最好情况 | 平均情况 | 最坏情况 |\n|------|----------|----------|----------|\n| 快速排序 | O(n log n) | O(n log n) | O(n²) |\n| 归并排序 | O(n log n) | O(n log n) | O(n log n) |\n| 冒泡排序 | O(n) | O(n²) | O(n²) |"},
+ {"type": "table",
+ "content": "| 算法 | 最好情况 | 平均情况 | 最坏情况 |\n|------|----------|----------|----------|\n| 快速排序 | O(n log n) | O(n log n) | O(n²) |\n| 归并排序 | O(n log n) | O(n log n) | O(n log n) |\n| 冒泡排序 | O(n) | O(n²) | O(n²) |"},
{"type": "equation-inline", "content": "T(n) = aT(n/b) + f(n)"},
{"type": "paragraph", "content": "其中 a ≥ 1, b > 1 是常数,f(n) 是正函数。"}
],
@@ -279,67 +284,67 @@ def quicksort(arr):
"content_type": "computer_science"
}
]
-
+
# 创建数据集
dataset = BenchmarkDataset(name="sample_dataset", description="示例评测数据集")
-
+
for sample_data in samples:
sample = DataSample.from_dict(sample_data)
dataset.add_sample(sample)
-
+
return dataset
def demo_basic_mock_evaluation():
"""演示基本评测流程"""
-
+
print("=== WebMainBench 基本使用示例 ===\n")
-
+
# 设置日志
setup_logging(level="INFO")
-
+
# 1. 创建或加载数据集
print("1. 创建示例数据集...")
dataset = create_sample_dataset()
print(f"数据集包含 {len(dataset)} 个样本")
print(f"数据集统计: {dataset.get_statistics()}\n")
-
+
# 2. 保存数据集到文件
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
-
+
dataset_path = data_dir / "sample_dataset.jsonl"
DataSaver.save_jsonl(dataset, dataset_path, include_results=False)
print(f"数据集已保存到: {dataset_path}\n")
-
+
# 3. 重新加载数据集
print("2. 重新加载数据集...")
loaded_dataset = DataLoader.load_jsonl(dataset_path)
print(f"加载的数据集包含 {len(loaded_dataset)} 个样本\n")
-
+
# 4. 列出可用的抽取器
print("3. 可用的抽取器:")
available_extractors = ExtractorFactory.list_available()
for extractor_name in available_extractors:
print(f" - {extractor_name}")
print()
-
+
# 5. 创建评测器
print("4. 创建评测器...")
evaluator = Evaluator()
print(f"可用的评测指标: {evaluator.metric_calculator.list_available_metrics()}\n")
-
+
# 6. 创建一个模拟抽取器进行演示
print("5. 创建模拟抽取器...")
-
+
from webmainbench.extractors import BaseExtractor, ExtractionResult
-
+
class MockExtractor(BaseExtractor):
"""模拟抽取器,用于演示"""
-
+
def _setup(self):
pass
-
+
def _extract_content(self, html, url=None):
# 简单的模拟抽取逻辑
if "标题" in html:
@@ -351,19 +356,19 @@ def _extract_content(self, html, url=None):
else:
content = "提取的内容"
content_list = [{"type": "paragraph", "content": "提取的内容"}]
-
+
return ExtractionResult(
content=content,
content_list=content_list,
success=True,
confidence_score=0.85
)
-
+
# 注册模拟抽取器
ExtractorFactory.register("mock", MockExtractor)
mock_extractor = ExtractorFactory.create("mock")
print("模拟抽取器已创建\n")
-
+
# 7. 运行评测
print("6. 运行评测...")
result = evaluator.evaluate(
@@ -371,21 +376,21 @@ def _extract_content(self, html, url=None):
extractor=mock_extractor,
max_samples=2 # 限制样本数量用于演示
)
-
+
# 8. 显示结果
print("\n7. 评测结果:")
print("=" * 50)
formatted_results = format_results(result.to_dict())
print(formatted_results)
-
+
# 9. 保存结果
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)
-
+
results_path = results_dir / "mock_evaluation_results.json"
DataSaver.save_evaluation_results(result, results_path)
print(f"\n结果已保存到: {results_path}")
-
+
# 10. 生成报告
report_path = results_dir / "mock_evaluation_report.csv"
DataSaver.save_summary_report(result, report_path)
@@ -394,18 +399,19 @@ def _extract_content(self, html, url=None):
def demo_extractor_comparison():
"""演示多抽取器对比"""
-
+
print("\n=== 多抽取器对比演示 ===\n")
-
+
# 创建数据集
dataset = create_sample_dataset()
-
+
# 创建多个模拟抽取器
from webmainbench.extractors import BaseExtractor, ExtractionResult
-
+
class ExtractorA(BaseExtractor):
def _setup(self):
pass
+
def _extract_content(self, html, url=None):
return ExtractionResult(
content="抽取器A的结果",
@@ -413,10 +419,11 @@ def _extract_content(self, html, url=None):
success=True,
confidence_score=0.9
)
-
+
class ExtractorB(BaseExtractor):
def _setup(self):
pass
+
def _extract_content(self, html, url=None):
return ExtractionResult(
content="抽取器B的结果",
@@ -424,33 +431,33 @@ def _extract_content(self, html, url=None):
success=True,
confidence_score=0.8
)
-
+
# 注册抽取器
ExtractorFactory.register("extractor_a", ExtractorA)
ExtractorFactory.register("extractor_b", ExtractorB)
-
+
# 运行对比
evaluator = Evaluator()
extractors = ["extractor_a", "extractor_b"]
-
+
results = evaluator.compare_extractors(
dataset=dataset,
extractors=extractors,
max_samples=2
)
-
+
# 显示对比结果
print("对比结果:")
print("-" * 40)
for extractor_name, result in results.items():
overall_score = result.overall_metrics.get('overall', 0)
print(f"{extractor_name}: {overall_score:.4f}")
-
+
# 保存多抽取器对比榜单
all_results = []
for extractor_name, result in results.items():
all_results.append(result.to_dict())
-
+
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)
leaderboard_path = results_dir / "leaderboard.csv"
@@ -460,17 +467,17 @@ def _extract_content(self, html, url=None):
def demo_llm_webkit_evaluation():
"""演示LLM-WebKit抽取器的6项指标评测"""
-
+
print("=== LLM-WebKit Extractor 6项指标评测示例 ===\n")
-
+
# 设置日志
setup_logging(level="INFO")
-
+
# 1. 创建包含各种内容类型的测试数据集
print("1. 创建包含多种内容类型的测试数据集...")
-
+
samples = []
-
+
# 样本1: 包含文本和代码
samples.append(DataSample(
id="text_code_sample",
@@ -502,11 +509,12 @@ def hello_world():
groundtruth_content_list=[
{"type": "heading", "content": "Python编程示例", "level": 1},
{"type": "text", "content": "这是一段关于Python编程的介绍文本。"},
- {"type": "code", "content": "def hello_world():\n print(\"Hello, World!\")\n return True", "language": "python"},
+ {"type": "code", "content": "def hello_world():\n print(\"Hello, World!\")\n return True",
+ "language": "python"},
{"type": "text", "content": "以上代码展示了一个简单的Python函数。"}
]
))
-
+
# 样本2: 包含表格
samples.append(DataSample(
id="table_sample",
@@ -546,10 +554,11 @@ def hello_world():
| 产品B | 200 | 3000 |""",
groundtruth_content_list=[
{"type": "heading", "content": "销售数据统计", "level": 2},
- {"type": "table", "content": "| 产品 | 销量 | 收入 |\n|------|------|------|\n| 产品A | 100 | 1000 |\n| 产品B | 200 | 3000 |"}
+ {"type": "table",
+ "content": "| 产品 | 销量 | 收入 |\n|------|------|------|\n| 产品A | 100 | 1000 |\n| 产品B | 200 | 3000 |"}
]
))
-
+
# 样本3: 包含公式
samples.append(DataSample(
id="formula_sample",
@@ -577,74 +586,74 @@ def hello_world():
{"type": "formula", "content": "\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi}"}
]
))
-
+
# 创建数据集并添加样本
dataset = BenchmarkDataset(name="llm_webkit_test", description="LLM-WebKit 6项指标测试数据集")
for sample in samples:
dataset.add_sample(sample)
-
+
print(f"测试数据集包含 {len(dataset)} 个样本")
print(f"样本类型: 文本+代码, 表格, 公式\n")
-
+
# 2. 创建LLM-WebKit抽取器
print("2. 创建LLM-WebKit抽取器...")
-
+
# 显示所有可用的抽取器
available_extractors = ExtractorFactory.list_available()
print(f"可用的抽取器: {available_extractors}")
-
+
# 直接创建LLM-WebKit抽取器,设置模型路径
config = {
"model_path": "/Users/chupei/model/checkpoint-3296"
}
extractor = ExtractorFactory.create("llm-webkit", config=config)
print(f"✅ LLM-WebKit抽取器创建成功,模型路径: {config['model_path']}")
-
+
print()
-
+
# 3. 创建评测器并显示所有可用指标
print("3. 创建评测器...")
evaluator = Evaluator()
available_metrics = evaluator.metric_calculator.list_available_metrics()
print(f"✅ 可用的评测指标 ({len(available_metrics)}项):")
-
+
# 按照6项指标分类显示
target_metrics = ["overall", "text_edit", "code_edit", "table_edit", "table_TEDS", "formula_edit"]
-
+
for metric in target_metrics:
if metric in available_metrics:
print(f" ✅ {metric}")
else:
print(f" ❌ {metric} (未注册)")
-
+
print()
-
+
# 4. 运行评测
print("4. 开始评测...")
print("=" * 60)
-
+
result = evaluator.evaluate(
dataset=dataset,
extractor=extractor,
max_samples=None # 评测所有样本
)
-
+
# 5. 显示详细的6项指标结果
print("\n5. 📊 6项指标详细评测结果:")
print("=" * 60)
-
+
results_dict = result.to_dict()
-
+
# 从overall_metrics中提取指标结果
metrics = results_dict.get('overall_metrics', {})
-
+
# 按照指标分类显示
print(f"\n🏆 综合指标:")
if 'overall' in metrics:
print(f" overall (综合得分): {metrics['overall']:.4f}")
else:
print(" overall: 未计算")
-
+
print(f"\n📝 文本相关指标:")
if 'text_edit' in metrics:
print(f" text_edit (文本编辑距离): {metrics['text_edit']:.4f}")
@@ -654,7 +663,7 @@ def hello_world():
print(f" code_edit (代码编辑距离): {metrics['code_edit']:.4f}")
else:
print(" code_edit: 未计算")
-
+
print(f"\n📊 表格相关指标:")
if 'table_edit' in metrics:
print(f" table_edit (表格编辑距离): {metrics['table_edit']:.4f}")
@@ -664,37 +673,37 @@ def hello_world():
print(f" table_TEDS (表格结构相似度): {metrics['table_TEDS']:.4f}")
else:
print(" table_TEDS: 未计算")
-
+
print(f"\n🧮 公式相关指标:")
if 'formula_edit' in metrics:
print(f" formula_edit (公式编辑距离): {metrics['formula_edit']:.4f}")
else:
print(" formula_edit: 未计算")
-
+
print(f"\n📈 详细统计:")
print(f" 总样本数: {len(dataset)}")
success_count = len([s for s in results_dict.get('sample_results', []) if s.get('extraction_success', False)])
failure_count = len(dataset) - success_count
print(f" 成功样本数: {success_count}")
print(f" 失败样本数: {failure_count}")
-
+
# 6. 保存结果到文件
print("\n" + "=" * 60)
print("6. 保存评测结果...")
-
+
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)
-
+
# 保存详细结果
results_path = results_dir / "llm_webkit_evaluation_results.json"
DataSaver.save_evaluation_results(result, results_path) # 直接传递result对象
print(f"✅ 详细结果已保存到: {results_path}")
-
+
# 生成CSV报告
report_path = results_dir / "llm_webkit_evaluation_report.csv"
DataSaver.save_summary_report(result, report_path) # 直接传递result对象
print(f"✅ CSV报告已保存到: {report_path}")
-
+
print("\n" + "=" * 60)
print("✅ LLM-WebKit 6项指标评测完成!")
@@ -702,84 +711,84 @@ def hello_world():
def demo_dataset_with_extraction():
"""演示保存带有抽取内容的数据集"""
print("=== 演示:保存带有抽取内容的数据集 ===")
-
+
from webmainbench import DataLoader, DataSaver, Evaluator, ExtractorFactory
from pathlib import Path
-
+
# 配置文件路径
data_dir = Path("data")
dataset_path = data_dir / "sample_dataset.jsonl"
# dataset_path = "/Users/chupei/Downloads/WebMainBench_dataset_merge_2549.jsonl"
-
+
print(f"📂 数据集文件: {dataset_path}")
-
+
# 🔧 创建llm-webkit抽取器(统一使用)
extractor_config = {"model_path": "/Users/chupei/model/checkpoint-3296"}
extractor = ExtractorFactory.create("llm-webkit", config=extractor_config)
print(f"🤖 使用抽取器: {extractor.name}")
-
+
# 创建评测器
evaluator = Evaluator()
-
+
# 🔧 选择评测模式:内存模式 vs 批处理模式
USE_BATCHED_MODE = True # 设置为True使用批处理模式(适用于大数据集)
-
+
if USE_BATCHED_MODE:
print("🔄 使用批处理模式(内存优化)")
-
+
# 🚀 批处理评测(适用于大数据集)
result = evaluator.evaluate_batched(
jsonl_file_path=dataset_path,
extractor=extractor, # 直接传递extractor对象
- batch_size=10, # 小批次
- max_samples=20 # 演示用
+ batch_size=10, # 小批次
+ max_samples=20 # 演示用
)
print(f"✅ 批处理评测完成,总体得分: {result.overall_metrics.get('overall', 0):.4f}")
-
+
# 为了保存带有抽取内容的数据集,需要重新加载原始数据集
# 注:这里只是短暂加载用于保存,不影响前面的内存优化评测
dataset = DataLoader.load_jsonl(dataset_path, include_results=False)
dataset.name = result.dataset_name
-
+
else:
print("🔄 使用传统内存模式")
-
+
# 从文件加载数据集
print(f"📂 从文件加载数据集: {dataset_path}")
dataset = DataLoader.load_jsonl(dataset_path, include_results=False)
dataset.name = "WebMainBench_with_extraction"
dataset.description = "演示抽取内容保存的测试数据集"
-
+
print(f"📊 加载数据集完成,包含 {len(dataset.samples)} 个样本")
-
+
# 运行评测
result = evaluator.evaluate(dataset, extractor)
-
+
print(f"✅ 评测完成,总体得分: {result.overall_metrics.get('overall', 0):.4f}")
-
+
# 保存带有抽取内容的数据集
results_dir = Path("results")
enriched_dataset_path = results_dir / f"{dataset.name}_with_{extractor.name}_extraction.jsonl"
-
+
DataSaver.save_dataset_with_extraction(
results=result,
- dataset=dataset,
+ dataset=dataset,
file_path=enriched_dataset_path,
extractor_name=extractor.name
)
-
+
print(f"💾 已保存带有抽取内容的数据集到: {enriched_dataset_path}")
-
+
# 保存评测结果和摘要报告
evaluation_results_path = results_dir / f"{dataset.name}_{extractor.name}_evaluation_results.json"
summary_report_path = results_dir / f"{dataset.name}_{extractor.name}_evaluation_report.csv"
-
+
DataSaver.save_evaluation_results(result, evaluation_results_path)
DataSaver.save_summary_report(result, summary_report_path)
-
+
print(f"📊 已保存评测结果到: {evaluation_results_path}")
print(f"📈 已保存摘要报告到: {summary_report_path}")
-
+
# 显示保存的字段信息
print("\n📋 保存的新字段包括:")
print(f" - {extractor.name}_content: 抽取的内容")
@@ -788,6 +797,7 @@ def demo_dataset_with_extraction():
print(f" - {extractor.name}_time: 抽取耗时")
print(f" - {extractor.name}_*_score: 各项指标分数")
+
if __name__ == "__main__":
try:
demo_basic_mock_evaluation()
@@ -795,8 +805,9 @@ def demo_dataset_with_extraction():
demo_extractor_comparison()
demo_dataset_with_extraction() # 演示保存带有抽取内容的数据集
print("\n✅ 示例运行完成!")
-
+
except Exception as e:
print(f"\n❌ 运行出错: {e}")
import traceback
- traceback.print_exc()
\ No newline at end of file
+
+ traceback.print_exc()
\ No newline at end of file
diff --git a/examples/demo.py b/examples/demo.py
index 62339b5..1006116 100644
--- a/examples/demo.py
+++ b/examples/demo.py
@@ -1,14 +1,14 @@
from webmainbench import DataLoader, Evaluator, ExtractorFactory
# 1. 加载评测数据集
-dataset = DataLoader.load_jsonl("data/sample_dataset.jsonl")
+dataset = DataLoader.load_jsonl("../data/sample_dataset.jsonl")
# 2. 创建抽取器
-extractor = ExtractorFactory.create("jina-ai")
+extractor = ExtractorFactory.create("llm-webkit")
# 3. 运行评测
evaluator = Evaluator()
result = evaluator.evaluate(dataset, extractor)
# 4. 查看结果
-print(f"Overall Score: {result}")
\ No newline at end of file
+print(f"Overall Score: {result.overall_metrics.get('overall', 0):.4f}")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..04708da
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+rapidfuzz
diff --git a/results/leaderboard.csv b/results/leaderboard.csv
index 8cc82c6..abdd754 100644
--- a/results/leaderboard.csv
+++ b/results/leaderboard.csv
@@ -1,3 +1,3 @@
-extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
-extractor_a,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
-extractor_b,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
+extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
+extractor_a,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
+extractor_b,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
diff --git a/results/llm_webkit_evaluation_report.csv b/results/llm_webkit_evaluation_report.csv
index 4271205..770ebab 100644
--- a/results/llm_webkit_evaluation_report.csv
+++ b/results/llm_webkit_evaluation_report.csv
@@ -1,2 +1,2 @@
-extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
-llm-webkit,llm_webkit_test,3,1.0,0.8224,0.8293,0.6667,1.0,0.963,0.6531
+extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
+llm-webkit,llm_webkit_test,3,1.0,0.8224,0.8293,0.6667,1.0,0.963,0.6531
diff --git a/results/mock_evaluation_report.csv b/results/mock_evaluation_report.csv
index 8fbc2bd..87c0e11 100644
--- a/results/mock_evaluation_report.csv
+++ b/results/mock_evaluation_report.csv
@@ -1,2 +1,2 @@
-extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
-mock,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
+extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
+mock,sample_dataset,2,1.0,0.5012,0.5,1.0,0.5,0.5,0.0062
diff --git a/webmainbench/extractors/__init__.py b/webmainbench/extractors/__init__.py
index a7e56a8..e94c225 100644
--- a/webmainbench/extractors/__init__.py
+++ b/webmainbench/extractors/__init__.py
@@ -9,6 +9,8 @@
from .llm_webkit_extractor import LlmWebkitExtractor
from .jina_extractor import JinaExtractor
+
+
__all__ = [
"BaseExtractor",
"ExtractionResult",
diff --git a/webmainbench/metrics/text_metrics.py b/webmainbench/metrics/text_metrics.py
index fd4c89c..1ebfb5e 100644
--- a/webmainbench/metrics/text_metrics.py
+++ b/webmainbench/metrics/text_metrics.py
@@ -6,7 +6,7 @@
import difflib
import re
from .base import BaseMetric, MetricResult
-
+from rapidfuzz.distance import Levenshtein
class EditDistanceMetric(BaseMetric):
"""Edit distance (Levenshtein distance) metric."""
@@ -62,23 +62,9 @@ def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> Metric
def _levenshtein_distance(self, s1: str, s2: str) -> int:
"""Calculate Levenshtein distance between two strings."""
- if len(s1) < len(s2):
- return self._levenshtein_distance(s2, s1)
-
- if len(s2) == 0:
- return len(s1)
-
- previous_row = list(range(len(s2) + 1))
- for i, c1 in enumerate(s1):
- current_row = [i + 1]
- for j, c2 in enumerate(s2):
- insertions = previous_row[j + 1] + 1
- deletions = current_row[j] + 1
- substitutions = previous_row[j] + (c1 != c2)
- current_row.append(min(insertions, deletions, substitutions))
- previous_row = current_row
-
- return previous_row[-1]
+
+ return Levenshtein.distance(s1, s2)
+
class BLEUMetric(BaseMetric):