@@ -1,5 +1,7 @@
import asyncio
import json
import random
import re
import uuid

from langchain_core.language_models import BaseChatModel
@@ -36,6 +38,13 @@ def _filter_docs(split_docs, chunk_size):
    return filtered_docs


def extract_img_urls(doc):
    """Extract the image URLs from a document."""
    pattern = r"!\[\]\((.*?)\)"
    # Find all matches of the bare markdown image syntax ![](url)
    img_urls = re.findall(pattern, doc)
    return img_urls
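
# A minimal usage sketch (illustrative input, not from the PR): the pattern only
# captures bare markdown images of the form ![](url); images with alt text,
# such as ![alt](url), are not matched by this regex.
#
#     >>> extract_img_urls("see ![](https://example.com/a.png) and ![alt](b.png)")
#     ['https://example.com/a.png']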

class GenerationService:
    def __init__(self, db: AsyncSession):
        self.db = db
@@ -226,6 +235,15 @@ async def _process_single_chunk_qa(

    Tasks that have already entered later stages of the pipeline (for example,
    other coroutines that are currently generating answers) are allowed to run
    to completion naturally.
    """
    # Randomly decide whether to run QA generation for the current chunk
    if random.random() > question_cfg.temperature:
        logger.info(
            f"Skip QA generation for chunk_index={chunk.chunk_index} in file_task={file_task.id} due to random decision."
        )
        # Update the file task's processed_chunks counter
        await self._increment_processed_chunks(file_task.id, 1)
        return False
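
    # Sketch of the sampling semantics assumed above: a chunk survives the check
    # with probability question_cfg.temperature, so temperature=1.0 keeps every
    # chunk and temperature=0.0 skips effectively all of them. For example:
    #
    #     keep_rate = sum(random.random() <= 0.3 for _ in range(100_000)) / 100_000
    #     # keep_rate converges to roughly 0.3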

    # Only enforce the cap when a global limit is configured; otherwise keep
    # the original behavior
    if max_qa_pairs is not None and max_qa_pairs > 0:
        from sqlalchemy import func
@@ -411,6 +429,11 @@ async def process_single_question(question: str):
    base_obj["instruction"] = question
    data_obj = base_obj

    # Extract any image URLs referenced in the chunk text
    img_urls = extract_img_urls(chunk_text)
    if img_urls:
        data_obj["img_urls"] = img_urls
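
    # Illustrative shape of the stored payload (hypothetical values): a chunk
    # containing "![](https://host/img.png)" would yield
    #     data_obj == {"instruction": "<question>", ..., "img_urls": ["https://host/img.png"]}
    # with the remaining fields carried over from base_obj.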

    record = SynthesisData(
        id=str(uuid.uuid4()),
        data=data_obj,