Skip to content

Commit 7dcd1a7

Browse files
committed
feat: add SyncWeb and GenerateRelated APIs for knowledge base synchronization and related generation
1 parent 0ae489a commit 7dcd1a7

File tree

5 files changed

+335
-5
lines changed

5 files changed

+335
-5
lines changed

apps/knowledge/api/knowledge.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from drf_spectacular.utils import OpenApiParameter
33

44
from common.mixins.api_mixin import APIMixin
5-
from common.result import ResultSerializer
5+
from common.result import ResultSerializer, DefaultResultSerializer
6+
from knowledge.serializers.common import GenerateRelatedSerializer
67
from knowledge.serializers.knowledge import KnowledgeBaseCreateRequest, KnowledgeModelSerializer, KnowledgeEditRequest, \
78
KnowledgeWebCreateRequest
89

@@ -206,3 +207,34 @@ def get_parameters():
206207
required=False,
207208
),
208209
]
210+
211+
212+
class SyncWebAPI(APIMixin):
    """OpenAPI description of the knowledge-base web synchronization endpoint."""

    @staticmethod
    def get_parameters():
        """Return the two required path parameters: ``workspace_id`` and ``knowledge_id``."""
        # (parameter name, human-readable description) for every path segment.
        path_params = (
            ("workspace_id", "工作空间id"),
            ("knowledge_id", "知识库id"),
        )
        return [
            OpenApiParameter(
                name=param_name,
                description=param_description,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
            for param_name, param_description in path_params
        ]

    @staticmethod
    def get_response():
        """Return the serializer class used to document the response body."""
        return DefaultResultSerializer
235+
236+
237+
class GenerateRelatedAPI(SyncWebAPI):
    """OpenAPI description of the generate-related endpoint.

    Path parameters and the response schema are inherited from ``SyncWebAPI``;
    only the request body schema is added here.
    """

    @staticmethod
    def get_request():
        """Return the serializer class used to document the request body."""
        return GenerateRelatedSerializer

apps/knowledge/serializers/knowledge.py

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,34 @@
1+
import logging
12
import os
3+
import re
4+
import traceback
25
from functools import reduce
36
from typing import Dict
47

58
import uuid_utils.compat as uuid
9+
from celery_once import AlreadyQueued
10+
from django.core import validators
611
from django.db import transaction, models
712
from django.db.models import QuerySet
13+
from django.db.models.functions import Reverse, Substr
814
from django.utils.translation import gettext_lazy as _
915
from rest_framework import serializers
1016

1117
from common.db.search import native_search, get_dynamics_model, native_page_search
1218
from common.db.sql_execute import select_list
19+
from common.event import ListenerManagement
1320
from common.exception.app_exception import AppApiException
1421
from common.utils.common import valid_license, post, get_file_content
22+
from common.utils.fork import Fork, ChildLink
23+
from common.utils.split_model import get_split_model
1524
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \
16-
ProblemParagraphMapping, ApplicationKnowledgeMapping
17-
from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id, MetaSerializer
25+
ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State
26+
from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id, MetaSerializer, \
27+
GenerateRelatedSerializer
1828
from knowledge.serializers.document import DocumentSerializers
1929
from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by_knowledge
20-
from knowledge.task.sync import sync_web_knowledge
30+
from knowledge.task.generate import generate_related_by_knowledge_id
31+
from knowledge.task.sync import sync_web_knowledge, sync_replace_web_knowledge
2132
from maxkb.conf import PROJECT_DIR
2233

2334

@@ -137,6 +148,35 @@ class Operate(serializers.Serializer):
137148
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
138149
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
139150

151+
def generate_related(self, instance: Dict, with_valid=True):
    """
    Mark the knowledge base's documents/paragraphs as pending and queue problem generation.

    :param instance: request payload; ``model_id``, ``prompt`` and ``state_list`` are read from it
    :param with_valid: when true, validate this serializer and ``instance`` before doing any work
    :raises AppApiException: when the generation task for this knowledge base is already queued
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True)
    # Read the declared ``knowledge_id`` field. The previous 'id' key is not among the
    # fields visible on this serializer, so it resolved to None and every filter below
    # (and the queued task) targeted no knowledge base.
    knowledge_id = self.data.get('knowledge_id')
    model_id = instance.get("model_id")
    prompt = instance.get("prompt")
    state_list = instance.get('state_list')

    # Every document of the knowledge base goes back to PENDING for this task type.
    ListenerManagement.update_status(
        QuerySet(Document).filter(knowledge_id=knowledge_id),
        TaskType.GENERATE_PROBLEM,
        State.PENDING
    )
    # Only paragraphs whose current GENERATE_PROBLEM state is in ``state_list`` are reset.
    # The state is taken as one character of the reversed ``status`` string, at the
    # position given by ``TaskType.GENERATE_PROBLEM.value``.
    ListenerManagement.update_status(
        QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, 1),
        ).filter(
            task_type_status__in=state_list, knowledge_id=knowledge_id
        ).values('id'),
        TaskType.GENERATE_PROBLEM,
        State.PENDING
    )
    # The helper returns a callable; invoke it right away.
    ListenerManagement.get_aggregation_document_status_by_knowledge_id(knowledge_id)()
    try:
        generate_related_by_knowledge_id.delay(knowledge_id, model_id, prompt, state_list)
    except AlreadyQueued as e:
        # Keep the original cause attached for easier debugging.
        raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) from e
179+
140180
def list_application(self, with_valid=True):
141181
if with_valid:
142182
self.is_valid(raise_exception=True)
@@ -340,3 +380,80 @@ def save_web(self, instance: Dict, with_valid=True):
340380
knowledge.save()
341381
sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
342382
return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []}
383+
384+
class SyncWeb(serializers.Serializer):
    """Serializer that re-synchronizes a web-type knowledge base from its source URL."""
    id = serializers.CharField(required=True, label=_('knowledge id'))
    user_id = serializers.UUIDField(required=False, label=_('user id'))
    # The alternatives must be grouped: "^replace|complete$" parses as
    # "(^replace)|(complete$)" and would accept e.g. "replace_all" or "incomplete".
    sync_type = serializers.CharField(required=True, label=_('sync type'), validators=[
        validators.RegexValidator(regex=re.compile("^(replace|complete)$"),
                                  message=_('The synchronization type only supports:replace|complete'), code=500)])

    def is_valid(self, *, raise_exception=False):
        """
        Validate the fields, then check that the knowledge base exists and is a web type.

        Validation failures always raise, regardless of ``raise_exception``.

        :return: True when validation succeeds
        :raises AppApiException: if the knowledge base is missing or is not of type WEB
        """
        super().is_valid(raise_exception=True)
        first = QuerySet(Knowledge).filter(id=self.data.get("id")).first()
        if first is None:
            raise AppApiException(300, _('id does not exist'))
        if first.type != KnowledgeType.WEB:
            raise AppApiException(500, _('Synchronization is only supported for web site types'))
        # Honour the Serializer.is_valid contract of returning a boolean.
        return True

    def sync(self, with_valid=True):
        """
        Run the synchronization selected by ``sync_type``.

        :param with_valid: when true, validate the serializer first
        :return: True once the synchronization has been started
        :raises AppApiException: if ``sync_type`` is not a supported synchronization type
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        sync_type = self.data.get('sync_type')
        knowledge_id = self.data.get('id')
        knowledge = QuerySet(Knowledge).get(id=knowledge_id)
        # Explicit dispatch table instead of attribute lookup by string concatenation.
        sync_handlers = {
            'replace': self.replace_sync,
            'complete': self.complete_sync,
        }
        sync_handler = sync_handlers.get(sync_type)
        if sync_handler is None:
            raise AppApiException(500, _('The synchronization type only supports:replace|complete'))
        sync_handler(knowledge)
        return True

    @staticmethod
    def get_sync_handler(knowledge):
        """
        Build the callback that handles one fetched page for ``knowledge``.

        :param knowledge: knowledge base the fetched pages belong to
        :return: ``handler(child_link, response)``; errors inside it are logged, not raised
        """

        def handler(child_link: ChildLink, response: Fork.Response):
            if response.status == 200:
                try:
                    # Use the link text as the document name, falling back to the URL.
                    document_name = child_link.tag.text if child_link.tag is not None and len(
                        child_link.tag.text.strip()) > 0 else child_link.url
                    paragraphs = get_split_model('web.md').parse(response.content)
                    source_url = child_link.url.strip()
                    first = QuerySet(Document).filter(
                        meta__source_url=source_url,
                        knowledge=knowledge
                    ).first()
                    if first is not None:
                        # A document for this URL already exists: re-sync it.
                        DocumentSerializers.Sync(data={'document_id': first.id}).sync()
                    else:
                        # No document yet: insert a new one.
                        DocumentSerializers.Create(data={'knowledge_id': knowledge.id}).save(
                            {'name': document_name, 'paragraphs': paragraphs,
                             'meta': {'source_url': source_url,
                                      'selector': knowledge.meta.get('selector')},
                             # Same enum as the type check in is_valid.
                             'type': KnowledgeType.WEB}, with_valid=True)
                except Exception as e:
                    logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

        return handler

    def replace_sync(self, knowledge):
        """
        Replace synchronization: queue a re-crawl of the knowledge base's source URL.

        :param knowledge: knowledge base whose ``meta`` holds ``source_url`` and optionally ``selector``
        :return: None
        """
        url = knowledge.meta.get('source_url')
        selector = knowledge.meta.get('selector')
        sync_replace_web_knowledge.delay(str(knowledge.id), url, selector)

    def complete_sync(self, knowledge):
        """
        Complete synchronization: delete all documents of the knowledge base, then synchronize again.

        :param knowledge: knowledge base to wipe and re-synchronize
        :return: None
        """
        # Delete the problem/paragraph associations.
        QuerySet(ProblemParagraphMapping).filter(knowledge=knowledge).delete()
        # Delete the documents.
        QuerySet(Document).filter(knowledge=knowledge).delete()
        # Delete the paragraphs.
        QuerySet(Paragraph).filter(knowledge=knowledge).delete()
        # Delete the embeddings.
        delete_embedding_by_knowledge(self.data.get('id'))
        # Synchronize.
        self.replace_sync(knowledge)

apps/knowledge/task/generate.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import logging
2+
import traceback
3+
4+
from celery_once import QueueOnce
5+
from django.db.models import QuerySet
6+
from django.db.models.functions import Reverse, Substr
7+
from django.utils.translation import gettext_lazy as _
8+
from langchain_core.messages import HumanMessage
9+
10+
from common.config.embedding_config import ModelManage
11+
from common.event import ListenerManagement
12+
from common.utils.page_utils import page, page_desc
13+
from knowledge.models import Paragraph, Document, Status, TaskType, State
14+
from knowledge.task.handler import save_problem
15+
from models_provider.models import Model
16+
from models_provider.tools import get_model
17+
from ops import celery_app
18+
19+
max_kb_error = logging.getLogger("max_kb_error")
20+
max_kb = logging.getLogger("max_kb")
21+
22+
23+
def get_llm_model(model_id):
    """Resolve the LLM for ``model_id`` through ``ModelManage.get_model``.

    The stored ``Model`` row is looked up first; the factory handed to
    ``ModelManage.get_model`` builds the model from that row.
    """
    model_record = QuerySet(Model).filter(id=model_id).first()

    def build_model(_id):
        return get_model(model_record)

    return ModelManage.get_model(model_id, build_model)
26+
27+
28+
def generate_problem_by_paragraph(paragraph, llm_model, prompt):
    """
    Ask the LLM for problems about one paragraph and persist them.

    The paragraph's GENERATE_PROBLEM status is set to STARTED, then to SUCCESS
    after the problems are saved, or to FAILURE if anything raises. An empty
    LLM answer returns early without a further status change.

    :param paragraph: paragraph row; ``id``, ``content``, ``title``, ``knowledge_id`` and ``document_id`` are read
    :param llm_model: model object exposing ``invoke``
    :param prompt: template in which ``{data}`` and ``{title}`` are substituted
    """

    def set_status(state):
        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
                                         state)

    try:
        set_status(State.STARTED)
        res = llm_model.invoke(
            [HumanMessage(content=prompt.replace('{data}', paragraph.content).replace('{title}', paragraph.title))])
        if (res.content is None) or (len(res.content) == 0):
            return
        # One candidate problem per line of the answer.
        for problem in res.content.split('\n'):
            save_problem(paragraph.knowledge_id, paragraph.document_id, paragraph.id, problem)
        set_status(State.SUCCESS)
    except Exception as e:
        # Record why generation failed instead of discarding the error silently.
        max_kb_error.error(
            f'Generate problem for paragraph {getattr(paragraph, "id", None)} failed: {str(e)}{traceback.format_exc()}')
        set_status(State.FAILURE)
44+
45+
46+
def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False):
    """
    Build the per-page worker that generates problems for a list of paragraphs.

    :param llm_model: model passed on to ``generate_problem_by_paragraph``
    :param prompt: prompt template passed on to ``generate_problem_by_paragraph``
    :param post_apply: callable invoked after each processed paragraph
    :param is_the_task_interrupted: callable checked before each paragraph; a truthy result stops the page
    :return: function taking a paragraph list
    """

    def generate_problem(paragraph_list):
        for current_paragraph in paragraph_list:
            interrupted = is_the_task_interrupted()
            if not interrupted:
                generate_problem_by_paragraph(current_paragraph, llm_model, prompt)
                post_apply()
                continue
            # Interrupted: abandon the rest of this page.
            return

    return generate_problem
55+
56+
57+
def get_is_the_task_interrupted(document_id):
    """
    Build a zero-argument check telling whether problem generation for a document must stop.

    The check returns True when the document no longer exists or when its
    GENERATE_PROBLEM state is ``State.REVOKE``; otherwise False.
    """

    def is_the_task_interrupted():
        current_document = QuerySet(Document).filter(id=document_id).first()
        return current_document is None or bool(
            Status(current_document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE)

    return is_the_task_interrupted
65+
66+
67+
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']},
                 name='celery:generate_related_by_knowledge')
def generate_related_by_knowledge_id(knowledge_id, model_id, prompt, state_list=None):
    """
    Queue one problem-generation task per document of a knowledge base.

    Queuing is best effort: a document whose task cannot be queued is logged
    and skipped so the remaining documents are still processed.

    :param knowledge_id: knowledge base whose documents are processed
    :param model_id: LLM model id forwarded to each document task
    :param prompt: prompt template forwarded to each document task
    :param state_list: paragraph states to process, forwarded unchanged
    """
    document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
    for document in document_list:
        try:
            generate_related_by_document_id.delay(document.id, model_id, prompt, state_list)
        except Exception as e:
            # Do not abort the whole knowledge base, but leave a trace of the skip.
            max_kb_error.error(
                f'Failed to queue problem generation for document {document.id}: {str(e)}{traceback.format_exc()}')
76+
77+
78+
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
                 name='celery:generate_related_by_document')
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
    """
    Generate problems for the paragraphs of one document.

    Errors are logged rather than raised; the document status is always
    post-updated and the end of the run is always logged.

    :param document_id: document whose paragraphs are processed
    :param model_id: LLM model id used for generation
    :param prompt: prompt template with ``{data}`` / ``{title}`` placeholders
    :param state_list: paragraph GENERATE_PROBLEM states to process; defaults to every state
    """
    if state_list is None:
        state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
                      State.REVOKE.value,
                      State.REVOKED.value, State.IGNORED.value]
    try:
        is_the_task_interrupted = get_is_the_task_interrupted(document_id)
        if is_the_task_interrupted():
            return
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                         TaskType.GENERATE_PROBLEM,
                                         State.STARTED)
        llm_model = get_llm_model(model_id)

        # Per-page worker that generates the problems.
        generate_problem = get_generate_problem(llm_model, prompt,
                                                ListenerManagement.get_aggregation_document_status(
                                                    document_id), is_the_task_interrupted)
        # Select paragraphs whose GENERATE_PROBLEM state (one character of the reversed
        # status string, at position TaskType.GENERATE_PROBLEM.value) is in state_list.
        query_set = QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
                                    1),
        ).filter(task_type_status__in=state_list, document_id=document_id)
        page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
    except Exception as e:
        # Log once, through the translatable message (the error used to be logged twice).
        max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
            document_id=document_id, error=str(e), traceback=traceback.format_exc()))
    finally:
        ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
        max_kb.info(_('End--->Generate problem: {document_id}').format(document_id=document_id))
111+
112+
113+
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
                 name='celery:generate_related_by_paragraph_list')
def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
    """
    Generate problems for an explicit list of paragraphs of one document.

    If the task is already interrupted when it starts, the document is marked
    REVOKED and nothing is generated. The document status is always
    post-updated, including when an error propagates.

    :param document_id: document the paragraphs belong to
    :param paragraph_id_list: ids of the paragraphs to process
    :param model_id: LLM model id used for generation
    :param prompt: prompt template with ``{data}`` / ``{title}`` placeholders
    """
    try:
        # One interruption check, shared by the paging loop and the per-paragraph loop
        # (previously an identical closure was re-declared below and shadowed this one).
        is_the_task_interrupted = get_is_the_task_interrupted(document_id)
        if is_the_task_interrupted():
            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                             TaskType.GENERATE_PROBLEM,
                                             State.REVOKED)
            return
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                         TaskType.GENERATE_PROBLEM,
                                         State.STARTED)
        llm_model = get_llm_model(model_id)
        # Per-page worker; the interruption check is forwarded so a revoke also stops
        # work between paragraphs, as in the per-document task.
        generate_problem = get_generate_problem(llm_model, prompt,
                                                ListenerManagement.get_aggregation_document_status(
                                                    document_id), is_the_task_interrupted)

        page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted)
    finally:
        ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)

apps/knowledge/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
path('workspace/<str:workspace_id>/knowledge/base', views.KnowledgeBaseView.as_view()),
99
path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
1010
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
11+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
12+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),
1113
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
1214
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
1315
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split_pattern', views.DocumentView.SplitPattern.as_view()),

0 commit comments

Comments
 (0)