Skip to content

Commit 3807cf1

Browse files
authored
feat: application chat (#3213)
1 parent 5f10b70 commit 3807cf1

21 files changed

+835
-138
lines changed

apps/application/chat_pipeline/I_base_chat_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212

1313
from rest_framework import serializers
1414

15-
from dataset.models import Paragraph
15+
from knowledge.models import Paragraph
1616

1717

1818
class ParagraphPipelineModel:
1919

20-
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
20+
def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str,
2121
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
2222
hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
2323
self.id = _id
2424
self.document_id = document_id
25-
self.dataset_id = dataset_id
25+
self.knowledge_id = knowledge_id
2626
self.content = content
2727
self.title = title
2828
self.status = status,
@@ -39,7 +39,7 @@ def to_dict(self):
3939
return {
4040
'id': self.id,
4141
'document_id': self.document_id,
42-
'dataset_id': self.dataset_id,
42+
'knowledge_id': self.knowledge_id,
4343
'content': self.content,
4444
'title': self.title,
4545
'status': self.status,
@@ -66,7 +66,7 @@ def add_paragraph(self, paragraph):
6666
if isinstance(paragraph, Paragraph):
6767
self.paragraph = {'id': paragraph.id,
6868
'document_id': paragraph.document_id,
69-
'dataset_id': paragraph.dataset_id,
69+
'knowledge_id': paragraph.knowledge_id,
7070
'content': paragraph.content,
7171
'title': paragraph.title,
7272
'status': paragraph.status,
@@ -106,7 +106,7 @@ def add_meta(self, meta: dict):
106106

107107
def build(self):
108108
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
109-
str(self.paragraph.get('dataset_id')),
109+
str(self.paragraph.get('knowledge_id')),
110110
self.paragraph.get('content'), self.paragraph.get('title'),
111111
self.paragraph.get('status'),
112112
self.paragraph.get('is_active'),

apps/application/chat_pipeline/step/chat_step/i_chat_step.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class PostResponseHandler:
4444
@abstractmethod
4545
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
4646
answer_text,
47-
manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
47+
manage, step, padding_problem_text: str = None, **kwargs):
4848
pass
4949

5050

@@ -68,8 +68,9 @@ class InstanceSerializer(serializers.Serializer):
6868
label=_("Completion Question"))
6969
# 是否使用流的形式输出
7070
stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
71-
client_id = serializers.CharField(required=True, label=_("Client id"))
72-
client_type = serializers.CharField(required=True, label=_("Client Type"))
71+
chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))
72+
73+
chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
7374
# 未查询到引用分段
7475
no_references_setting = NoReferencesSetting(required=True,
7576
label=_("No reference segment settings"))
@@ -104,6 +105,6 @@ def execute(self, message_list: List[BaseMessage],
104105
user_id: str = None,
105106
paragraph_list=None,
106107
manage: PipelineManage = None,
107-
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
108+
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
108109
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
109110
pass

apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@
2525
from application.chat_pipeline.pipeline_manage import PipelineManage
2626
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
2727
from application.flow.tools import Reasoning
28-
from application.models.application_api_key import ApplicationPublicAccessClient
29-
from common.constants.authentication_type import AuthenticationType
28+
from application.models import ApplicationChatUserStats, ChatUserType
3029
from models_provider.tools import get_model_instance_by_model_user_id
3130

3231

33-
def add_access_num(client_id=None, client_type=None, application_id=None):
34-
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
35-
application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
36-
application_id=application_id)
32+
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
33+
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
34+
chat_user_type) and application_id is not None:
35+
application_public_access_client = (QuerySet(ApplicationChatUserStats).filter(chat_user_id=chat_user_id,
36+
chat_user_type=chat_user_type,
37+
application_id=application_id)
3738
.first())
3839
if application_public_access_client is not None:
3940
application_public_access_client.access_num = application_public_access_client.access_num + 1
@@ -124,11 +125,9 @@ def event_content(response,
124125
request_token = 0
125126
response_token = 0
126127
write_context(step, manage, request_token, response_token, all_text)
127-
asker = manage.context.get('form_data', {}).get('asker', None)
128128
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
129-
all_text, manage, step, padding_problem_text, client_id,
130-
reasoning_content=reasoning_content if reasoning_content_enable else ''
131-
, asker=asker)
129+
all_text, manage, step, padding_problem_text,
130+
reasoning_content=reasoning_content if reasoning_content_enable else '')
132131
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
133132
[], '', True,
134133
request_token, response_token,
@@ -139,10 +138,8 @@ def event_content(response,
139138
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
140139
all_text = 'Exception:' + str(e)
141140
write_context(step, manage, 0, 0, all_text)
142-
asker = manage.context.get('form_data', {}).get('asker', None)
143141
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
144-
all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
145-
asker=asker)
142+
all_text, manage, step, padding_problem_text, reasoning_content='')
146143
add_access_num(client_id, client_type, manage.context.get('application_id'))
147144
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
148145
[], all_text,
@@ -165,7 +162,7 @@ def execute(self, message_list: List[BaseMessage],
165162
manage: PipelineManage = None,
166163
padding_problem_text: str = None,
167164
stream: bool = True,
168-
client_id=None, client_type=None,
165+
chat_user_id=None, chat_user_type=None,
169166
no_references_setting=None,
170167
model_params_setting=None,
171168
model_setting=None,
@@ -175,12 +172,13 @@ def execute(self, message_list: List[BaseMessage],
175172
if stream:
176173
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
177174
paragraph_list,
178-
manage, padding_problem_text, client_id, client_type, no_references_setting,
175+
manage, padding_problem_text, chat_user_id, chat_user_type,
176+
no_references_setting,
179177
model_setting)
180178
else:
181179
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
182180
paragraph_list,
183-
manage, padding_problem_text, client_id, client_type, no_references_setting,
181+
manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
184182
model_setting)
185183

186184
def get_details(self, manage, **kwargs):
@@ -235,7 +233,7 @@ def execute_stream(self, message_list: List[BaseMessage],
235233
paragraph_list=None,
236234
manage: PipelineManage = None,
237235
padding_problem_text: str = None,
238-
client_id=None, client_type=None,
236+
chat_user_id=None, chat_user_type=None,
239237
no_references_setting=None,
240238
model_setting=None):
241239
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
@@ -244,7 +242,8 @@ def execute_stream(self, message_list: List[BaseMessage],
244242
r = StreamingHttpResponse(
245243
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
246244
post_response_handler, manage, self, chat_model, message_list, problem_text,
247-
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
245+
padding_problem_text, chat_user_id, chat_user_type, is_ai_chat,
246+
model_setting),
248247
content_type='text/event-stream;charset=utf-8')
249248

250249
r['Cache-Control'] = 'no-cache'

apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
1616
IGenerateHumanMessageStep
1717
from application.models import ChatRecord
18-
from common.util.split_model import flat_map
18+
from common.utils.common import flat_map
1919

2020

2121
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):

apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class InstanceSerializer(serializers.Serializer):
2626
padding_problem_text = serializers.CharField(required=False,
2727
label=_("System completes question text"))
2828
# 需要查询的数据集id列表
29-
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
30-
label=_("Dataset id list"))
29+
knowledge_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
30+
label=_("Dataset id list"))
3131
# 需要排除的文档id
3232
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
3333
label=_("List of document ids to exclude"))
@@ -55,7 +55,7 @@ def _run(self, manage: PipelineManage):
5555
self.context['paragraph_list'] = paragraph_list
5656

5757
@abstractmethod
58-
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
58+
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
5959
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
6060
search_mode: str = None,
6161
user_id=None,
@@ -65,7 +65,7 @@ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_documen
6565
:param similarity: 相关性
6666
:param top_n: 查询多少条
6767
:param problem_text: 用户问题
68-
:param dataset_id_list: 需要查询的数据集id列表
68+
:param knowledge_id_list: 需要查询的数据集id列表
6969
:param exclude_document_id_list: 需要排除的文档id
7070
:param exclude_paragraph_id_list: 需要排除段落id
7171
:param padding_problem_text 补全问题

apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,42 +35,33 @@ def get_model_by_id(_id, user_id):
3535
return model
3636

3737

38-
def get_embedding_id(dataset_id_list):
39-
<<<<<<< Updated upstream:apps/chat/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
40-
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
41-
if len(set([dataset.embedding_model_id for dataset in dataset_list])) > 1:
42-
raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
43-
if len(dataset_list) == 0:
44-
raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
45-
return dataset_list[0].embedding_model_id
46-
=======
47-
knowledge_list = QuerySet(Knowledge).filter(id__in=dataset_id_list)
38+
def get_embedding_id(knowledge_id_list):
    """Return the embedding model id shared by the selected knowledge bases.

    :param knowledge_id_list: ids used to filter ``Knowledge`` rows (``id__in``).
    :return: the ``embedding_mode_id`` of the first matching knowledge base.
    :raises Exception: if the matching knowledge bases carry more than one
        distinct ``embedding_mode_id``, or if no knowledge base matches.
    """
    knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
    # NOTE(review): the attribute is spelled ``embedding_mode_id`` (not
    # ``embedding_model_id``) -- confirm this matches the Knowledge model field.
    distinct_embedding_ids = {knowledge.embedding_mode_id for knowledge in knowledge_list}
    # The inconsistency check runs before the emptiness check, so an empty
    # result falls through to the second error.
    if len(distinct_embedding_ids) > 1:
        raise Exception(
            _("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
    if len(knowledge_list) == 0:
        raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
    first_knowledge = knowledge_list[0]
    return first_knowledge.embedding_mode_id
54-
>>>>>>> Stashed changes:apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
5546

5647

5748
class BaseSearchDatasetStep(ISearchDatasetStep):
5849

59-
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
50+
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
6051
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
6152
search_mode: str = None,
6253
user_id=None,
6354
**kwargs) -> List[ParagraphPipelineModel]:
64-
if len(dataset_id_list) == 0:
55+
if len(knowledge_id_list) == 0:
6556
return []
6657
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
67-
model_id = get_embedding_id(dataset_id_list)
58+
model_id = get_embedding_id(knowledge_id_list)
6859
model = get_model_by_id(model_id, user_id)
6960
self.context['model_name'] = model.name
7061
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
7162
embedding_value = embedding_model.embed_query(exec_problem_text)
7263
vector = VectorStore.get_embedding_vector()
73-
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
64+
embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, exclude_document_id_list,
7465
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
7566
if embedding_list is None:
7667
return []

apps/application/flow/i_step_node.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from rest_framework.exceptions import ValidationError, ErrorDetail
1919

2020
from application.flow.common import Answer, NodeChunk
21-
from application.models import ChatRecord
22-
from application.models import ApplicationChatClientStats
21+
from application.models import ChatRecord, ChatUserType
22+
from application.models import ApplicationChatUserStats
2323
from common.constants.authentication_type import AuthenticationType
2424
from common.field.common import InstanceField
2525

@@ -45,10 +45,10 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):
4545

4646

4747
class WorkFlowPostHandler:
48-
def __init__(self, chat_info, client_id, client_type):
48+
def __init__(self, chat_info, chat_user_id, chat_user_type):
4949
self.chat_info = chat_info
50-
self.client_id = client_id
51-
self.client_type = client_type
50+
self.chat_user_id = chat_user_id
51+
self.chat_user_type = chat_user_type
5252

5353
def handler(self, chat_id,
5454
chat_record_id,
@@ -84,13 +84,13 @@ def handler(self, chat_id,
8484
run_time=time.time() - workflow.context['start_time'],
8585
index=0)
8686
asker = workflow.context.get('asker', None)
87-
self.chat_info.append_chat_record(chat_record, self.client_id, asker)
88-
# 重新设置缓存
89-
chat_cache.set(chat_id,
90-
self.chat_info, timeout=60 * 30)
91-
if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
92-
application_public_access_client = (QuerySet(ApplicationChatClientStats)
93-
.filter(client_id=self.client_id,
87+
self.chat_info.append_chat_record(chat_record)
88+
self.chat_info.set_cahce()
89+
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
90+
self.chat_user_type):
91+
application_public_access_client = (QuerySet(ApplicationChatUserStats)
92+
.filter(chat_user_id=self.chat_user_id,
93+
chat_user_type=self.chat_user_type,
9494
application_id=self.chat_info.application.id).first())
9595
if application_public_access_client is not None:
9696
application_public_access_client.access_num = application_public_access_client.access_num + 1

0 commit comments

Comments
 (0)