diff --git a/docs/LLM_WEBKIT.md b/docs/LLM_WEBKIT.md new file mode 100644 index 0000000..9631f37 --- /dev/null +++ b/docs/LLM_WEBKIT.md @@ -0,0 +1,284 @@ +# LLM-WebKit Extractor 使用指南 + +## 概述 + +LLM-WebKit Extractor集成了大语言模型(LLM)推理能力,能够智能地理解HTML结构并准确提取主要内容。 + +## 安装依赖 + +```bash +# 基础依赖 +pip install torch transformers + +# VLLM推理引擎 +pip install vllm + +# LLM-WebKit HTML处理 +pip install llm_web_kit + +# 可选:加速库 +pip install flash-attn # GPU加速 +``` + +## 基本使用 + +### 1. 创建Extractor + +```python +from webmainbench.extractors import ExtractorFactory + +# 使用默认配置 +extractor = ExtractorFactory.create("llm-webkit") + +# 使用自定义配置 +config = { + "model_path": "/Users/chupei/model/checkpoint-3296", + "use_logits_processor": True, + "temperature": 0.0, + "max_item_count": 500 +} +extractor = ExtractorFactory.create("llm-webkit", config=config) +``` + +### 2. 提取内容 + +```python +html_content = """ + + + +
<div>主要文章内容</div>
+ + + +""" + +result = extractor.extract(html_content) + +if result.success: + print(f"提取的内容: {result.content}") + print(f"置信度: {result.confidence_score}") + print(f"分类结果: {result.metadata['classification_result']}") +else: + print(f"提取失败: {result.error_message}") +``` + +## 配置选项 + +### LLMInferenceConfig 参数 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `model_path` | str | `"/share/liukaiwen/models/qwen3-0.6b/checkpoint-3296"` | LLM模型路径 | +| `use_logits_processor` | bool | `True` | 是否启用JSON格式约束 | +| `max_tokens` | int | `32768` | 最大输入token数 | +| `temperature` | float | `0.0` | 采样温度(0=确定性输出) | +| `top_p` | float | `0.95` | 核采样参数 | +| `max_output_tokens` | int | `8192` | 最大输出token数 | +| `tensor_parallel_size` | int | `1` | 张量并行大小 | +| `dtype` | str | `"bfloat16"` | 模型精度 | +| `max_item_count` | int | `1000` | 处理的最大item数量 | + +### 模式配置示例 + +#### 快速模式(适合批量处理) +```python +fast_config = { + "use_logits_processor": False, # 禁用格式约束提高速度 + "temperature": 0.0, + "max_item_count": 200, + "max_output_tokens": 2048, + "dtype": "float16" # 更快的精度 +} +``` + +#### 精确模式(适合高质量提取) +```python +precise_config = { + "use_logits_processor": True, # 启用格式约束 + "temperature": 0.0, + "max_item_count": 1000, + "max_output_tokens": 8192, + "dtype": "bfloat16" # 更好的精度 +} +``` + +#### 分布式模式(多GPU) +```python +distributed_config = { + "tensor_parallel_size": 4, # 使用4个GPU + "dtype": "bfloat16", + "max_item_count": 2000, # 可以处理更复杂的HTML +} +``` + +## 工作流程详解 + +### 1. HTML预处理 +```python +# 使用llm_web_kit简化HTML结构 +simplified_html, raw_tag_html, _ = simplify_html(original_html) +``` + +### 2. 复杂度检查 +```python +item_count = simplified_html.count('_item_id') +if item_count > max_item_count: + # 跳过过于复杂的HTML + return error_result +``` + +### 3. LLM推理 +```python +# 创建分类提示 +prompt = create_classification_prompt(simplified_html) + +# 使用VLLM生成分类结果 +output = model.generate(prompt, sampling_params) +classification = parse_json_output(output) +``` + +### 4. 内容重建 +```python +# 根据分类结果重建主要内容 +main_content, content_list = reconstruct_content( + original_html, classification +) +``` + +## 提示工程 + +### 分类标准 + +**主要内容 ("main")**: +- 文章正文、博客内容 +- 问答的问题和答案 +- 论坛的主要讨论内容 +- 嵌入的相关图片和媒体 + +**辅助内容 ("other")**: +- 导航菜单、侧边栏、页脚 +- 元数据(作者、时间、浏览量等) +- 广告和推广内容 +- 相关推荐和建议内容 + +### 自定义提示模板 + +如果需要修改分类逻辑,可以继承类并重写提示模板: + +```python +class CustomLlmWebkitExtractor(LlmWebkitExtractor): + CLASSIFICATION_PROMPT = """ + 您的自定义分类提示... + 输入HTML: {alg_html} + """ +``` + +## 性能优化建议 + +### 1. 模型选择 +- **小模型** (0.5B-1B): 适合快速批处理,准确率略低 +- **中等模型** (3B-7B): 平衡性能和准确率 +- **大模型** (13B+): 最高准确率,适合高质量需求 + +### 2. 硬件配置 +```python +# 单GPU配置 +config = { + "tensor_parallel_size": 1, + "dtype": "bfloat16", # A100/H100推荐 + # "dtype": "float16", # V100/RTX推荐 +} + +# 多GPU配置 +config = { + "tensor_parallel_size": 4, # 4个GPU + "dtype": "bfloat16", +} +``` + +### 3. 批处理优化 +```python +# 预加载模型避免重复初始化 +extractor = ExtractorFactory.create("llm-webkit", config) + +# 批量处理 +for html in html_list: + result = extractor.extract(html) + process_result(result) +``` + +## 故障排除 + +### 常见问题 + +1. **模型加载失败** + ``` + RuntimeError: Failed to load LLM model + ``` + - 检查模型路径是否正确 + - 确保有足够的GPU内存 + - 验证模型格式兼容性 + +2. **JSON解析错误** + ``` + Warning: LLM output is not valid JSON + ``` + - 启用 `use_logits_processor=True` + - 检查提示模板格式 + - 降低temperature增加确定性 + +3. **内存不足** + ``` + CUDA out of memory + ``` + - 减少 `max_item_count` + - 降低 `max_output_tokens` + - 使用 `dtype="float16"` + - 增加 `tensor_parallel_size` + +4. 
**处理速度慢** + - 禁用 `use_logits_processor` + - 减少 `max_output_tokens` + - 使用更小的模型 + - 增加GPU并行度 + +### 调试技巧 + +```python +# 启用详细日志 +import logging +logging.basicConfig(level=logging.DEBUG) + +# 检查分类结果 +result = extractor.extract(html) +if result.success: + print("分类详情:", result.metadata['classification_result']) + print("LLM原始输出:", result.metadata['llm_output']) +``` + +## 集成示例 + +### 与WebMainBench评测框架集成 + +```python +from webmainbench.evaluator import Evaluator +from webmainbench.data import BenchmarkDataset + +# 创建数据集 +dataset = BenchmarkDataset.from_file("test_data.jsonl") + +# 配置LLM-WebKit extractor +config = { + "model_path": "/path/to/model", + "use_logits_processor": True, + "max_item_count": 500 +} + +# 运行评测 +evaluator = Evaluator(extractor_name="llm-webkit", extractor_config=config) +results = evaluator.evaluate(dataset) + +print(f"平均得分: {results['overall_score']}") +print(f"处理速度: {results['processing_speed']} samples/s") +``` diff --git a/TEDS.md b/docs/TEDS.md similarity index 100% rename from TEDS.md rename to docs/TEDS.md diff --git a/examples/llm_webkit_usage.py b/examples/llm_webkit_usage.py new file mode 100644 index 0000000..4300f55 --- /dev/null +++ b/examples/llm_webkit_usage.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +""" +LLM-WebKit Extractor使用示例 + +本示例展示如何使用集成了VLLM推理能力的LLM-WebKit extractor。 +""" + +import time +from webmainbench.extractors import ExtractorFactory + + +def main(): + print("🚀 LLM-WebKit Extractor 使用示例\n") + + # 1. 创建带有自定义配置的extractor + config = { + "model_path": "/Users/chupei/model/checkpoint-3296", # 替换为您的模型路径 + "use_logits_processor": True, # 启用JSON格式约束 + "temperature": 0.0, # 确定性输出 + "max_item_count": 500, # 处理的最大item数量 + "max_output_tokens": 4096, # 最大输出token数 + "dtype": "bfloat16", # 模型精度 + "tensor_parallel_size": 1 # 张量并行大小 + } + + try: + extractor = ExtractorFactory.create("llm-webkit", config=config) + print(f"✅ Extractor创建成功: {extractor.description}") + print(f"📋 版本: {extractor.version}") + print(f"⚙️ 配置: {extractor.inference_config.__dict__}\n") + + except Exception as e: + print(f"❌ Extractor创建失败: {e}") + print("💡 请确保已安装所需依赖:") + print(" pip install vllm transformers torch llm_web_kit") + return + + # 2. 准备测试HTML(包含_item_id属性的结构化HTML) + test_html = """ + + + 测试文章 - 人工智能的发展趋势 + + + + +
+        <!-- 示例正文结构(标签与 _item_id 标注仅为演示) -->
+        <div _item_id="1">
+            <h1 _item_id="2">人工智能的发展趋势</h1>
+            <div _item_id="3">作者:张三 | 发布时间:2024-01-15 | 阅读量:1,234</div>
+        </div>
+        <div _item_id="4">
+            <p _item_id="5">人工智能(AI)技术正在快速发展,对各行各业产生深远影响。本文将探讨AI的主要发展趋势和未来展望。</p>
+            <h2 _item_id="6">1. 机器学习的进步</h2>
+            <p _item_id="7">深度学习和大语言模型的突破使得AI系统能够理解和生成更自然的语言,在对话、翻译、创作等领域表现出色。</p>
+            <h2 _item_id="8">2. 自动化应用</h2>
+            <p _item_id="9">从制造业的机器人到软件开发的代码生成,AI正在各个领域实现流程自动化,提高效率并降低成本。</p>
+            <h2 _item_id="10">3. 个性化服务</h2>
+            <p _item_id="11">基于用户数据的个性化推荐和服务正变得越来越精准,为用户提供更好的体验。</p>
+        </div>
+ + + + + + + """ + + # 3. 执行内容提取 + print("🔍 开始内容提取...") + start_time = time.time() + + try: + result = extractor.extract(test_html) + end_time = time.time() + + print(f"⏱️ 提取耗时: {end_time - start_time:.2f}秒\n") + + # 4. 显示提取结果 + if result.success: + print("✅ 内容提取成功!\n") + + print("📄 提取的主要内容:") + print("=" * 50) + print(result.content[:500] + "..." if len(result.content) > 500 else result.content) + print("=" * 50) + + print(f"\n📊 提取统计:") + print(f" • 内容长度: {len(result.content)} 字符") + print(f" • 置信度: {result.confidence_score:.3f}") + print(f" • 标题: {result.title}") + print(f" • 语言: {result.language}") + print(f" • 提取时间: {result.extraction_time:.3f}秒") + + if result.content_list: + print(f" • 结构化内容块: {len(result.content_list)}个") + for i, item in enumerate(result.content_list[:3]): # 显示前3个 + print(f" [{i+1}] {item.get('type', 'unknown')}: {item.get('content', '')[:50]}...") + + else: + print("❌ 内容提取失败") + print(f"错误信息: {result.error_message}") + if result.error_traceback: + print(f"错误详情:\n{result.error_traceback}") + + except Exception as e: + print(f"❌ 提取过程中发生异常: {e}") + + print("\n🎯 高级功能说明:") + print("• 智能分类: 使用LLM理解HTML元素语义,准确区分主要内容和辅助内容") + print("• 格式约束: 通过logits processor确保LLM输出有效的JSON格式") + print("• 性能优化: 自动跳过过于复杂的HTML,支持延迟加载模型") + print("• 详细反馈: 提供分类结果、置信度和性能指标") + + +if __name__ == "__main__": + main() + + print("\n💡 使用提示:") + print("1. 确保已安装所需依赖: vllm, transformers, torch, llm_web_kit") + print("2. 设置正确的模型路径") + print("3. 根据硬件资源调整tensor_parallel_size和dtype") + print("4. 对于大规模HTML,适当调整max_item_count限制") + print("5. 使用use_logits_processor=True确保输出格式可靠性") \ No newline at end of file diff --git a/setup.py b/setup.py index 9f086df..ee1cb91 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,8 @@ "lxml>=4.9.0", "jsonlines>=3.1.0", "requests>=2.28.0", + "beautifulsoup4==4.12.0", + "numpy>=1.21.0,<2.0.0", # 避免NumPy 2.x兼容性问题 ], extras_require={ "all": [ @@ -42,6 +44,12 @@ "rouge-score>=0.1.2", "unstructured>=0.10.0", ], + "llm": [ + "torch>=2.0.0", + "transformers>=4.30.0", + "vllm>=0.4.0", + "llm_web_kit>=3.0.0", + ], "nlp": [ "nltk>=3.8", "rouge-score>=0.1.2", diff --git a/webmainbench/extractors/llm_webkit_extractor.py b/webmainbench/extractors/llm_webkit_extractor.py index b50cb57..b49df9b 100644 --- a/webmainbench/extractors/llm_webkit_extractor.py +++ b/webmainbench/extractors/llm_webkit_extractor.py @@ -1,102 +1,730 @@ """ -LLM-WebKit extractor implementation. +LLM-WebKit extractor implementation with advanced LLM inference. 
""" -from typing import Dict, Any, Optional +import json +import re +import time +from typing import Dict, Any, Optional, List +from enum import Enum +from dataclasses import dataclass +import torch + from .base import BaseExtractor, ExtractionResult from .factory import extractor +@dataclass +class LLMInferenceConfig: + """Configuration for LLM inference.""" + model_path: str = "/path/to/your/model" + use_logits_processor: bool = True + max_tokens: int = 32768 # 最大输入token数 + temperature: float = 0.0 + top_p: float = 0.95 + max_output_tokens: int = 8192 # 最大输出token数 + tensor_parallel_size: int = 1 # 张量并行大小 + dtype: str = "bfloat16" # 数据类型 + max_item_count: int = 1000 # 最大item数量 + gpu_memory_utilization: float = 0.8 # GPU内存利用率 + enforce_eager: bool = True # 使用eager模式 + + +class TokenState(Enum): + """Token states for JSON format enforcement.""" + Left_bracket = 0 + Right_bracket = 1 + Space_quote = 2 + Quote_colon_quote = 3 + Quote_comma = 4 + Main_other = 5 + Number = 6 + + +class TokenStateManager: + """Manages token states to ensure valid JSON output.""" + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + token_id_map = { + TokenState.Left_bracket: ["{"], + TokenState.Right_bracket: ["}"], + TokenState.Space_quote: [' "'], + TokenState.Quote_colon_quote: ['":"'], + TokenState.Quote_comma: ['",'], + TokenState.Main_other: ["main", "other"], + TokenState.Number: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], + } + self.token_id_map = {k: [self.tokenizer.encode(v)[0] for v in token_id_map[k]] for k in token_id_map} + + def mask_other_logits(self, logits: torch.Tensor, remained_ids: List[int]): + """Mask logits to only allow specific token IDs.""" + remained_logits = {ids: logits[ids].item() for ids in remained_ids} + new_logits = torch.ones_like(logits) * -float('inf') + for id in remained_ids: + new_logits[id] = remained_logits[id] + return new_logits + + def calc_max_count(self, prompt_token_ids: List[int]): + """Calculate maximum count of items from prompt.""" + pattern_list = [716, 1203, 842, 428] + for idx in range(len(prompt_token_ids) - len(pattern_list), -1, -1): + if all(prompt_token_ids[idx + i] == pattern_list[i] for i in range(len(pattern_list))): + num_idx = idx + len(pattern_list) + num_ids = [] + while num_idx < len(prompt_token_ids) and prompt_token_ids[num_idx] in self.token_id_map[TokenState.Number]: + num_ids.append(prompt_token_ids[num_idx]) + num_idx += 1 + return int(self.tokenizer.decode(num_ids)) + return 1 + + def find_last_complete_number(self, input_ids: List[int]): + """Find the last complete number in input IDs.""" + if not input_ids: + return -1, "null", -1 + + tail_number_ids = [] + last_idx = len(input_ids) - 1 + while last_idx >= 0 and input_ids[last_idx] in self.token_id_map[TokenState.Number]: + tail_number_ids.insert(0, input_ids[last_idx]) + last_idx -= 1 + + tail_number = int(self.tokenizer.decode(tail_number_ids)) if tail_number_ids else -1 + + while last_idx >= 0 and input_ids[last_idx] not in self.token_id_map[TokenState.Number]: + last_idx -= 1 + + if last_idx < 0: + return tail_number, "tail", tail_number + + last_number_ids = [] + while last_idx >= 0 and input_ids[last_idx] in self.token_id_map[TokenState.Number]: + last_number_ids.insert(0, input_ids[last_idx]) + last_idx -= 1 + + last_number = int(self.tokenizer.decode(last_number_ids)) + + if tail_number == last_number + 1: + return tail_number, "tail", tail_number + return last_number, "non_tail", tail_number + + def process_logit(self, prompt_token_ids: List[int], 
input_ids: List[int], logits: torch.Tensor): + """Process logits to enforce JSON format.""" + if not input_ids: + return self.mask_other_logits(logits, self.token_id_map[TokenState.Left_bracket]) + + last_token = input_ids[-1] + + if last_token == self.token_id_map[TokenState.Right_bracket][0]: + return self.mask_other_logits(logits, [151645]) + elif last_token == self.token_id_map[TokenState.Left_bracket][0]: + return self.mask_other_logits(logits, self.token_id_map[TokenState.Space_quote]) + elif last_token == self.token_id_map[TokenState.Space_quote][0]: + last_number, _, _ = self.find_last_complete_number(input_ids) + if last_number == -1: + next_char = '1' + else: + next_char = str(last_number + 1)[0] + return self.mask_other_logits(logits, self.tokenizer.encode(next_char)) + elif last_token in self.token_id_map[TokenState.Number]: + last_number, state, tail_number = self.find_last_complete_number(input_ids) + if state == "tail": + return self.mask_other_logits(logits, self.token_id_map[TokenState.Quote_colon_quote]) + else: + next_str = str(last_number + 1) + next_char = next_str[len(str(tail_number))] + return self.mask_other_logits(logits, self.tokenizer.encode(next_char)) + elif last_token == self.token_id_map[TokenState.Quote_colon_quote][0]: + return self.mask_other_logits(logits, self.token_id_map[TokenState.Main_other]) + elif last_token in self.token_id_map[TokenState.Main_other]: + return self.mask_other_logits(logits, self.token_id_map[TokenState.Quote_comma]) + elif last_token == self.token_id_map[TokenState.Quote_comma][0]: + last_number, _, _ = self.find_last_complete_number(input_ids) + max_count = self.calc_max_count(prompt_token_ids) + if last_number >= max_count: + return self.mask_other_logits(logits, self.token_id_map[TokenState.Right_bracket]) + else: + return self.mask_other_logits(logits, self.token_id_map[TokenState.Space_quote]) + + return logits + + @extractor("llm-webkit") class LlmWebkitExtractor(BaseExtractor): - """Extractor using LLM-WebKit.""" + """Advanced LLM-WebKit extractor with intelligent content classification.""" + + version = "2.0.0" + description = "Advanced LLM-WebKit extractor with intelligent content classification" - version = "1.0.0" - description = "LLM-WebKit based content extractor" + # 分类提示模板 + CLASSIFICATION_PROMPT = """As a front-end engineering expert in HTML, your task is to analyze the given HTML structure and accurately classify elements with the _item_id attribute as either "main" (primary content) or "other" (supplementary content). Your goal is to precisely extract the primary content of the page, ensuring that only the most relevant information is labeled as "main" while excluding navigation, metadata, and other non-essential elements. + +Guidelines for Classification: + +Primary Content ("main") +Elements that constitute the core content of the page should be classified as "main". These typically include: +✅ For Articles, News, and Blogs: +The main text body of the article, blog post, or news content. +Images embedded within the main content that contribute to the article. +✅ For Forums & Discussion Threads: +The original post in the thread. +Replies and discussions that are part of the main conversation. +✅ For Q&A Websites: +The question itself posted by a user. +Answers to the question and replies to answers that contribute to the discussion. +✅ For Other Content-Based Pages: +Any rich text, paragraphs, or media that serve as the primary focus of the page. 
+ +Supplementary Content ("other") +Elements that do not contribute to the primary content but serve as navigation, metadata, or supporting information should be classified as "other". These include: +❌ Navigation & UI Elements: +Menus, sidebars, footers, breadcrumbs, and pagination links. +"Skip to content" links and accessibility-related text. +❌ Metadata & User Information: +Article titles, author names, timestamps, and view counts. +Like counts, vote counts, and other engagement metrics. +❌ Advertisements & Promotional Content: +Any section labeled as "Advertisement" or "Sponsored". +Social media sharing buttons, follow prompts, and external links. +❌ Related & Suggested Content: +"Read More", "Next Article", "Trending Topics", and similar sections. +Lists of related articles, tags, and additional recommendations. + +Task Instructions: +You will be provided with a simplified HTML structure containing elements with an _item_id attribute. Your job is to analyze each element's function and determine whether it should be classified as "main" or "other". + +Response Format: +Return a JSON object where each key is the _item_id value, and the corresponding value is either "main" or "other", as in the following example: +{{"1": "other","2": "main","3": "other"}} + +🚨 Important Notes: +Do not include any explanations in the output—only return the JSON. +Ensure high accuracy by carefully distinguishing between primary content and supplementary content. +Err on the side of caution—if an element seems uncertain, classify it as "other" unless it clearly belongs to the main content. + +Input HTML: +{alg_html} + +Output format should be a JSON-formatted string representing a dictionary where keys are item_id strings and values are either 'main' or 'other'. Make sure to include ALL item_ids from the input HTML.""" + + def __init__(self, name: str, config: Optional[Dict[str, Any]] = None): + super().__init__(name, config) + self.inference_config = LLMInferenceConfig() + self.model = None + self.tokenizer = None + self.token_state_manager = None + + # Override config if provided + if config: + for key, value in config.items(): + if hasattr(self.inference_config, key): + setattr(self.inference_config, key, value) def _setup(self) -> None: - """Setup the LLM-WebKit extractor.""" + """Setup the LLM-WebKit extractor with advanced inference capabilities.""" + # 初始化模块引用 + self._simplify_html = None + self._PreDataJson = None + self._PreDataJsonKey = None + self._MapItemToHtmlTagsParser = None + self._SamplingParams = None + self._model_loaded = False + + # 检查各个依赖模块的可用性 + missing_modules = [] + + # 检查 llm_web_kit try: - # Import llm_web_kit modules - from llm_web_kit.simple import extract_html - from llm_web_kit.libs.html_utils import get_cc_select_html - from lxml.html import fromstring + from llm_web_kit.main_html_parser.simplify_html.simplify_html import simplify_html + from llm_web_kit.input.pre_data_json import PreDataJson, PreDataJsonKey + from llm_web_kit.main_html_parser.parser.tag_mapping import MapItemToHtmlTagsParser - self._extract_html = extract_html - self._get_cc_select_html = get_cc_select_html - self._fromstring = fromstring + self._simplify_html = simplify_html + self._PreDataJson = PreDataJson + self._PreDataJsonKey = PreDataJsonKey + self._MapItemToHtmlTagsParser = MapItemToHtmlTagsParser except ImportError as e: - raise RuntimeError(f"Failed to import llm_web_kit: {e}") + missing_modules.append(f"llm_web_kit: {e}") + + # 检查 transformers(延迟到实际使用时) + self._transformers_available = False + 
try: + import transformers + self._transformers_available = True + except ImportError as e: + missing_modules.append(f"transformers: {e}") + + # 检查 vllm(延迟到实际使用时) + self._vllm_available = False + try: + import vllm + from vllm import SamplingParams + self._SamplingParams = SamplingParams + self._vllm_available = True + except ImportError as e: + missing_modules.append(f"vllm: {e}") + + # 如果关键模块缺失,提供详细的错误信息 + if missing_modules: + error_msg = "LLM-WebKit extractor requires additional dependencies:\n" + error_msg += "\n".join([f" • {module}" for module in missing_modules]) + error_msg += "\n\nTo install dependencies:\n" + error_msg += " pip install llm_web_kit transformers vllm torch\n" + error_msg += "\nFor CPU-only usage (limited functionality):\n" + error_msg += " pip install llm_web_kit transformers torch --index-url https://download.pytorch.org/whl/cpu" + + raise RuntimeError(error_msg) + + def _load_model(self): + """延迟加载LLM模型和tokenizer.""" + if self._model_loaded: + return + + # 检查依赖是否可用 + if not self._transformers_available: + raise RuntimeError("transformers library is not available. Please install it: pip install transformers") + + import torch + + # 检测运行环境 + is_apple_silicon = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() + has_cuda = torch.cuda.is_available() + + print(f"🔍 检测到运行环境:") + print(f" CUDA: {has_cuda}") + print(f" Apple Silicon (MPS): {is_apple_silicon}") + + # 对于Apple Silicon,优先使用transformers而不是vLLM(避免兼容性问题) + if is_apple_silicon and not has_cuda: + print("🍎 Apple Silicon环境检测到,使用transformers模式以避免vLLM兼容性问题") + self._load_transformers_model() + else: + # 其他环境尝试使用vLLM + if not self._vllm_available: + print("⚠️ vLLM不可用,回退到transformers模式") + self._load_transformers_model() + else: + self._load_vllm_model() + + def _load_transformers_model(self): + """使用transformers加载模型(兼容性更好)""" + try: + from transformers import AutoTokenizer, AutoModelForCausalLM + import torch + + print(f"📦 使用transformers加载模型: {self.inference_config.model_path}") + + # 加载tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + self.inference_config.model_path, + trust_remote_code=True + ) + + # 设置设备 + if torch.cuda.is_available(): + device = "cuda" + torch_dtype = torch.bfloat16 if self.inference_config.dtype == "bfloat16" else torch.float16 + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = "mps" + torch_dtype = torch.float16 # MPS目前不支持bfloat16 + else: + device = "cpu" + torch_dtype = torch.float32 + + print(f"🎯 使用设备: {device}, 数据类型: {torch_dtype}") + + # 加载模型 + self.model = AutoModelForCausalLM.from_pretrained( + self.inference_config.model_path, + trust_remote_code=True, + torch_dtype=torch_dtype, + device_map=device if device != "mps" else None # MPS不支持device_map + ) + + if device == "mps": + self.model = self.model.to(device) + + self.model.eval() + + # 标记为transformers模式 + self._use_transformers = True + self._model_loaded = True + + print("✅ transformers模型加载成功!") + + except Exception as e: + raise RuntimeError(f"Failed to load transformers model: {e}") + + def _load_vllm_model(self): + """使用vLLM加载模型(高性能但兼容性要求高)""" + try: + from transformers import AutoTokenizer + from vllm import LLM + + print(f"⚡ 使用vLLM加载模型: {self.inference_config.model_path}") + + # 加载tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + self.inference_config.model_path, + trust_remote_code=True + ) + + # vLLM配置 + model_kwargs = { + "model": self.inference_config.model_path, + "trust_remote_code": True, + "dtype": self.inference_config.dtype, + 
"tensor_parallel_size": self.inference_config.tensor_parallel_size, + "max_model_len": self.inference_config.max_tokens, + "max_num_batched_tokens": max(self.inference_config.max_tokens, 8192), + "gpu_memory_utilization": self.inference_config.gpu_memory_utilization, + "enforce_eager": self.inference_config.enforce_eager, + "disable_custom_all_reduce": True, + "load_format": "auto", + } + + self.model = LLM(**model_kwargs) + + # 初始化token状态管理器 + if self.inference_config.use_logits_processor: + self.token_state_manager = TokenStateManager(self.tokenizer) + + # 标记为vLLM模式 + self._use_transformers = False + self._model_loaded = True + + print("✅ vLLM模型加载成功!") + + except Exception as e: + print(f"⚠️ vLLM加载失败,回退到transformers: {e}") + self._load_transformers_model() + + def _create_prompt(self, simplified_html: str) -> str: + """创建分类提示.""" + return self.CLASSIFICATION_PROMPT.format(alg_html=simplified_html) + + def _add_template(self, prompt: str) -> str: + """添加聊天模板.""" + messages = [ + {"role": "user", "content": prompt} + ] + chat_prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) + return chat_prompt + + def _generate_with_transformers(self, prompt: str) -> str: + """使用transformers生成文本""" + try: + import torch + + # Tokenize输入 + inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.inference_config.max_tokens) + + # 移动到正确的设备 + device = self.model.device + inputs = {k: v.to(device) for k, v in inputs.items()} + + # 生成配置 + generation_config = { + "max_new_tokens": self.inference_config.max_output_tokens, + "temperature": self.inference_config.temperature, + "top_p": self.inference_config.top_p, + "do_sample": self.inference_config.temperature > 0, + "pad_token_id": self.tokenizer.eos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + } + + print(f"🔄 开始生成文本 (max_new_tokens: {generation_config['max_new_tokens']})") + + # 生成 + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + **generation_config + ) + + # 解码输出(只取新生成的部分) + input_length = inputs['input_ids'].shape[1] + generated_ids = outputs[0][input_length:] + generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) + + print(f"✅ 生成完成,输出长度: {len(generated_text)}") + print(f"🔍 LLM原始输出: {repr(generated_text[:200])}") # 显示前200字符用于调试 + + # 提取JSON部分 + json_result = self._extract_json_from_text(generated_text) + print(f"🔍 提取的JSON: {repr(json_result[:200])}") # 显示JSON结果 + return json_result + + except Exception as e: + print(f"⚠️ transformers生成失败: {e}") + return "{}" + + def _extract_json_from_text(self, text: str) -> str: + """从生成的文本中提取JSON""" + # 查找JSON部分 + start_idx = text.find("{") + end_idx = text.rfind("}") + 1 + + if start_idx != -1 and end_idx != 0: + json_str = text[start_idx:end_idx] + # 清理JSON + json_str = json_str.strip() + json_str = re.sub(r',\s*}', '}', json_str) + try: + # 验证JSON + json.loads(json_str) + return json_str + except: + pass + + return "{}" + + def _clean_output(self, output) -> str: + """清理LLM输出,提取JSON.""" + prediction = output[0].outputs[0].text + + # 提取JSON + start_idx = prediction.rfind("{") + end_idx = prediction.rfind("}") + 1 + + if start_idx != -1 and end_idx != -1: + json_str = prediction[start_idx:end_idx] + json_str = re.sub(r',\s*}', '}', json_str) # 清理JSON + try: + json.loads(json_str) # 验证 + return json_str + except: + return "{}" + else: + return "{}" + + def _reformat_classification_result(self, json_str: str) -> Dict[str, int]: + """重新格式化分类结果.""" + 
try: + data = json.loads(json_str) + return {"item_id " + k: 1 if v == "main" else 0 for k, v in data.items()} + except json.JSONDecodeError: + return {} + + def _reconstruct_content(self, original_html: str, classification_result: Dict[str, int], url: str = None) -> tuple: + """根据分类结果重建主要内容.""" + try: + # 按照ray_test_qa.py的正确流程 + # 第一步:使用MapItemToHtmlTagsParser生成main_html + main_html = self._generate_main_html_with_parser(original_html, classification_result) + print(f"🔧 MapItemToHtmlTagsParser生成的main_html长度: {len(main_html)}") + + if not main_html.strip(): + print("⚠️ 没有生成main_html,返回空结果") + return "", [] + + # 第二步:使用llm-webkit的方法将main_html提取成content,传入URL + content, content_list = self._extract_content_from_main_html(main_html, url) + print(f"✅ content提取成功: {len(content)}字符, {len(content_list)}个内容块") + + return content, content_list + + except Exception as e: + print(f"❌ Content reconstruction failed: {e}") + return "", [] + + def _generate_main_html_with_parser(self, original_html: str, classification_result: Dict[str, int]) -> str: + """使用MapItemToHtmlTagsParser生成main_html(按照ray_test_qa.py的流程)""" + try: + # 获取typical_raw_tag_html (简化的HTML) + simplified_html, typical_raw_tag_html, _ = self._simplify_html(original_html) + print(f"🔧 simplified HTML长度: {len(simplified_html)}") + print(f"🔧 typical_raw_tag_html长度: {len(typical_raw_tag_html)}") + + # 按照ray_test_qa.py的流程 + pre_data = self._PreDataJson({}) + pre_data[self._PreDataJsonKey.LLM_RESPONSE] = classification_result + pre_data[self._PreDataJsonKey.TYPICAL_RAW_HTML] = original_html + pre_data[self._PreDataJsonKey.TYPICAL_RAW_TAG_HTML] = typical_raw_tag_html + + print(f"🔧 PreDataJson设置完成,开始解析...") + + # 使用MapItemToHtmlTagsParser解析 + parser = self._MapItemToHtmlTagsParser({}) + pre_data = parser.parse_single(pre_data) + + # 获取生成的main_html + main_html = pre_data.get(self._PreDataJsonKey.TYPICAL_MAIN_HTML, "") + + print(f"✅ MapItemToHtmlTagsParser完成,main_html长度: {len(main_html)}") + return main_html + + except Exception as e: + print(f"❌ MapItemToHtmlTagsParser失败: {e}") + return "" + + def _extract_content_from_main_html(self, main_html: str, url: str = None) -> tuple: + """使用llm-webkit的方法将main_html提取成content""" + try: + from llm_web_kit.simple import extract_html_to_md + import traceback + + print(f"🔧 开始使用llm-webkit简单接口提取content...") + + # 使用简单接口提取markdown,传入URL + content = extract_html_to_md(url or "", main_html, clip_html=False) + + print(f"✅ llm-webkit提取完成: {len(content)}字符") + + # 暂不构建content_list,直接返回空列表 + return content.strip(), [] + + except Exception as e: + print(f"❌ llm-webkit提取失败: {e}") + print(f"❌ 错误详情: {traceback.format_exc()}") + return "", [] + def _extract_content(self, html: str, url: str = None) -> ExtractionResult: """ - Extract content using LLM-WebKit. + 使用高级LLM推理提取内容. 
Args: - html: HTML content to extract from - url: Optional URL of the page + html: HTML内容 + url: 可选的页面URL Returns: - ExtractionResult instance + ExtractionResult实例 """ + start_time = time.time() + try: - # Use llm_web_kit for extraction - result = self._extract_html(html, url=url) - - # Extract additional groundtruth if HTML has cc-select annotations - groundtruth_content = "" - if "cc-select" in html: - try: - element = self._fromstring(html) - cc_selected = self._get_cc_select_html(element) - # Convert selected elements to markdown - # This would need implementation based on your needs - groundtruth_content = self._element_to_markdown(cc_selected) - except Exception as e: - print(f"Warning: Failed to extract cc-select content: {e}") - - return ExtractionResult( - content=result.get('content', ''), - content_list=result.get('content_list', []), - title=result.get('title'), - language=result.get('language'), - # confidence_score=self._calculate_confidence(result), + # 步骤1: HTML简化处理 + simplified_html, typical_raw_tag_html, _ = self._simplify_html(html) + + # 步骤2: 检查长度限制 + item_count = simplified_html.count('_item_id') + if item_count > self.inference_config.max_item_count: + return ExtractionResult.create_error_result( + f"HTML too complex: {item_count} items > {self.inference_config.max_item_count} limit" + ) + + if item_count == 0: + return ExtractionResult.create_error_result("No _item_id found in simplified HTML") + + # 步骤3: 延迟加载模型 + self._load_model() + + # 步骤4: 创建提示并进行LLM推理 + prompt = self._create_prompt(simplified_html) + chat_prompt = self._add_template(prompt) + + # 配置采样参数 + if self.inference_config.use_logits_processor and self.token_state_manager: + sampling_params = self._SamplingParams( + temperature=self.inference_config.temperature, + top_p=self.inference_config.top_p, + max_tokens=self.inference_config.max_output_tokens, + logits_processors=[self.token_state_manager.process_logit] + ) + else: + sampling_params = self._SamplingParams( + temperature=self.inference_config.temperature, + top_p=self.inference_config.top_p, + max_tokens=self.inference_config.max_output_tokens + ) + + # 根据模型类型选择生成方式 + if hasattr(self, '_use_transformers') and self._use_transformers: + # 使用transformers生成 + json_result = self._generate_with_transformers(chat_prompt) + else: + # 使用vLLM生成 + output = self.model.generate(chat_prompt, sampling_params) + json_result = self._clean_output(output) + + # 步骤5: 格式转换和内容重建 + print(f"🔄 开始格式转换...") + classification_result = self._reformat_classification_result(json_result) + print(f"🔍 格式转换结果: {len(classification_result)} 个分类项") + + print(f"🔄 开始重建内容...") + main_content, content_list = self._reconstruct_content(html, classification_result, url) + print(f"🔍 重建结果: 主内容长度={len(main_content)}, 内容块数量={len(content_list) if content_list else 0}") + + # 计算置信度 + confidence = self._calculate_confidence(main_content, content_list, item_count) + + extraction_time = time.time() - start_time + + # 创建结果对象 + result = ExtractionResult( + content=main_content, + content_list=content_list, + title=self._extract_title(html), + language=self._detect_language(main_content), + confidence_score=confidence, + extraction_time=extraction_time, success=True ) + # 添加调试信息到错误消息字段(用于开发调试) + debug_info = f"item_count: {item_count}, llm_output_length: {len(json_result)}" + if not result.success: + result.error_message = f"{result.error_message or ''} | {debug_info}".strip(' |') + + return result + except Exception as e: + extraction_time = time.time() - start_time return 
ExtractionResult.create_error_result( - f"LLM-WebKit extraction failed: {str(e)}" + f"LLM-WebKit extraction failed: {str(e)}", + extraction_time=extraction_time ) - def _element_to_markdown(self, element) -> str: - """Convert HTML element to markdown (placeholder implementation).""" - # This is a placeholder - you would implement actual HTML to markdown conversion + def _extract_title(self, html: str) -> Optional[str]: + """提取页面标题.""" try: - from llm_web_kit.libs.html_utils import element_to_html - html_str = element_to_html(element) - # Here you would convert HTML to markdown - # For now, just return the HTML - return html_str + import re + title_match = re.search(r']*>(.*?)', html, re.IGNORECASE | re.DOTALL) + if title_match: + return title_match.group(1).strip() except: - return "" + pass + return None - def _calculate_confidence(self, result: Dict[str, Any]) -> float: - """Calculate extraction confidence score.""" - # Simple confidence calculation based on content length and structure - content = result.get('content', '') - content_list = result.get('content_list', []) + def _detect_language(self, content: str) -> Optional[str]: + """检测内容语言.""" + if not content: + return None + + # 简单的语言检测逻辑 + chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', content)) + english_chars = len(re.findall(r'[a-zA-Z]', content)) + if chinese_chars > english_chars: + return "zh" + elif english_chars > 0: + return "en" + else: + return None + + def _calculate_confidence(self, content: str, content_list: List[Dict], item_count: int) -> float: + """计算提取置信度.""" if not content: return 0.0 - # Factor in content length - length_score = min(len(content) / 1000, 1.0) # Normalize to 1000 chars + # 基于内容长度的评分 + length_score = min(len(content) / 1000, 1.0) + + # 基于结构化内容的评分 + structure_score = min(len(content_list) / 10, 1.0) if content_list else 0.0 - # Factor in structure (content_list) - structure_score = min(len(content_list) / 10, 1.0) # Normalize to 10 blocks + # 基于处理复杂度的评分(item数量越多,置信度稍微降低) + complexity_penalty = max(0, (item_count - 100) / 900) # 100-1000范围内线性降低 + complexity_score = max(0.5, 1.0 - complexity_penalty) - # Combine scores - confidence = (length_score * 0.7 + structure_score * 0.3) + # 综合评分 + confidence = (length_score * 0.5 + structure_score * 0.3 + complexity_score * 0.2) return min(confidence, 1.0) \ No newline at end of file
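
For reference, a minimal standalone sketch of the confidence weighting implemented in `_calculate_confidence` above; the sample values (content length, block count, item count) are hypothetical and only illustrate how the 0.5/0.3/0.2 weights and the complexity penalty combine.

```python
# Standalone sketch of the confidence formula from _calculate_confidence above.
# The sample values below are hypothetical; only the weighting mirrors the diff.
content_length = 1800   # characters of extracted markdown
content_blocks = 4      # entries in content_list
item_count = 350        # number of _item_id elements in the simplified HTML

length_score = min(content_length / 1000, 1.0)          # -> 1.0 (saturates at 1000 chars)
structure_score = min(content_blocks / 10, 1.0)         # -> 0.4 (saturates at 10 blocks)
complexity_penalty = max(0, (item_count - 100) / 900)   # -> ~0.278 (linear above 100 items)
complexity_score = max(0.5, 1.0 - complexity_penalty)   # -> ~0.722 (floored at 0.5)

confidence = min(length_score * 0.5 + structure_score * 0.3 + complexity_score * 0.2, 1.0)
print(round(confidence, 3))  # ~0.764
```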