25 changes: 17 additions & 8 deletions apps/application/chat_pipeline/I_base_chat_pipeline.py
@@ -18,8 +18,9 @@
 class ParagraphPipelineModel:
 
     def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str,
-                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
-                 hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
+                 is_active: bool, comprehensive_score: float, similarity: float, knowledge_name: str,
+                 document_name: str,
+                 hit_handling_method: str, directly_return_similarity: float, knowledge_type, meta: dict = None):
         self.id = _id
         self.document_id = document_id
         self.knowledge_id = knowledge_id
@@ -29,11 +30,12 @@ def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str,
         self.is_active = is_active
         self.comprehensive_score = comprehensive_score
         self.similarity = similarity
-        self.dataset_name = dataset_name
+        self.knowledge_name = knowledge_name
         self.document_name = document_name
         self.hit_handling_method = hit_handling_method
         self.directly_return_similarity = directly_return_similarity
         self.meta = meta
+        self.knowledge_type = knowledge_type
 
     def to_dict(self):
         return {
@@ -46,8 +48,9 @@ def to_dict(self):
             'is_active': self.is_active,
             'comprehensive_score': self.comprehensive_score,
             'similarity': self.similarity,
-            'dataset_name': self.dataset_name,
+            'knowledge_name': self.knowledge_name,
             'document_name': self.document_name,
+            'knowledge_type': self.knowledge_type,
             'meta': self.meta,
         }

@@ -57,7 +60,8 @@ def __init__(self):
         self.paragraph = {}
         self.comprehensive_score = None
         self.document_name = None
-        self.dataset_name = None
+        self.knowledge_name = None
+        self.knowledge_type = None
         self.hit_handling_method = None
         self.directly_return_similarity = 0.9
         self.meta = {}
@@ -76,8 +80,12 @@ def add_paragraph(self, paragraph):
         self.paragraph = paragraph
         return self
 
-    def add_dataset_name(self, dataset_name):
-        self.dataset_name = dataset_name
+    def add_knowledge_name(self, knowledge_name):
+        self.knowledge_name = knowledge_name
         return self
 
+    def add_knowledge_type(self, knowledge_type):
+        self.knowledge_type = knowledge_type
+        return self
+
     def add_document_name(self, document_name):
@@ -110,8 +118,9 @@ def build(self):
                                       self.paragraph.get('content'), self.paragraph.get('title'),
                                       self.paragraph.get('status'),
                                       self.paragraph.get('is_active'),
-                                      self.comprehensive_score, self.similarity, self.dataset_name,
+                                      self.comprehensive_score, self.similarity, self.knowledge_name,
                                       self.document_name, self.hit_handling_method, self.directly_return_similarity,
+                                      self.knowledge_type,
                                       self.meta)
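For orientation, a minimal sketch of the constructor after this change — the positional order comes from the new signature above; every sample value is made up:

```python
# Hedged sketch: construct the model with the renamed/new fields and read
# them back through to_dict(); all values below are illustrative.
model = ParagraphPipelineModel(
    'p-1', 'd-1', 'k-1',                 # _id, document_id, knowledge_id
    'Refunds take 5 days.', 'Refunds',   # content, title
    'ok', True,                          # status, is_active
    0.92, 0.88,                          # comprehensive_score, similarity
    'product-faq',                       # knowledge_name (was dataset_name)
    'faq.md',                            # document_name
    'optimization', 0.95,                # hit_handling_method, directly_return_similarity
    'base',                              # knowledge_type (new in this change)
    meta={})
data = model.to_dict()
assert data['knowledge_name'] == 'product-faq' and data['knowledge_type'] == 'base'
```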


@@ -18,7 +18,7 @@
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
 from langchain.schema.messages import HumanMessage, AIMessage
-from langchain_core.messages import AIMessageChunk
+from langchain_core.messages import AIMessageChunk, SystemMessage
 from rest_framework import status
 
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
@@ -196,7 +196,8 @@ def get_details(self, manage, **kwargs):

     @staticmethod
     def reset_message_list(message_list: List[BaseMessage], answer_text):
-        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+        result = [{'role': 'user' if isinstance(message, HumanMessage) else (
+            'system' if isinstance(message, SystemMessage) else 'ai'), 'content': message.content} for
                   message
                   in
                   message_list]
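Concretely, the comprehension now maps each message class to its own role instead of folding system prompts into 'ai'. A small self-contained sketch (sample messages are illustrative):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

message_list = [SystemMessage(content='You are a helpful assistant.'),
                HumanMessage(content='Hi'),
                AIMessage(content='Hello!')]
# Same expression as in the diff above.
result = [{'role': 'user' if isinstance(message, HumanMessage) else (
    'system' if isinstance(message, SystemMessage) else 'ai'),
           'content': message.content} for message in message_list]
# -> roles: ['system', 'user', 'ai']
```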
@@ -79,7 +79,8 @@ def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineM
                  .add_paragraph(paragraph)
                  .add_similarity(find_embedding.get('similarity'))
                  .add_comprehensive_score(find_embedding.get('comprehensive_score'))
-                 .add_dataset_name(paragraph.get('dataset_name'))
+                 .add_knowledge_name(paragraph.get('knowledge_name'))
+                 .add_knowledge_type(paragraph.get('knowledge_type'))
                  .add_document_name(paragraph.get('document_name'))
                  .add_hit_handling_method(paragraph.get('hit_handling_method'))
                  .add_directly_return_similarity(paragraph.get('directly_return_similarity'))
@@ -32,6 +32,8 @@ class RerankerStepNodeSerializer(serializers.Serializer):
     question_reference_address = serializers.ListField(required=True)
     reranker_model_id = serializers.UUIDField(required=True)
     reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
+    show_knowledge = serializers.BooleanField(required=True,
+                                              label=_("The results are displayed in the knowledge sources"))
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
@@ -55,6 +57,6 @@ def _run(self):

                             reranker_list=reranker_list)
 
-    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
+    def execute(self, question, reranker_setting, reranker_list, reranker_model_id, show_knowledge,
                 **kwargs) -> NodeResult:
         pass
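Because show_knowledge is declared required=True, node parameters that omit it now fail validation. A hedged sketch — the keys mirror the serializer above, every value is made up:

```python
# is_valid above always raises on error, so a missing show_knowledge
# surfaces immediately. All values here are illustrative.
params = {
    'question_reference_address': ['start', 'question'],
    'reranker_model_id': '4b6d0a2e-0000-0000-0000-000000000000',
    'reranker_reference_list': [['knowledge_node', 'paragraph_list']],
    'show_knowledge': True,  # new: show hits in the answer's knowledge sources
}
RerankerStepNodeSerializer(data=params).is_valid(raise_exception=True)
```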
@@ -24,11 +24,9 @@ def merge_reranker_list(reranker_list, result=None):
         elif isinstance(document, dict):
             content = document.get('title', '') + document.get('content', '')
             title = document.get("title")
-            dataset_name = document.get("dataset_name")
-            document_name = document.get('document_name')
             result.append(
                 Document(page_content=str(document) if len(content) == 0 else content,
-                         metadata={'title': title, 'dataset_name': dataset_name, 'document_name': document_name}))
+                         metadata={'title': title, **document}))
         else:
             result.append(Document(page_content=str(document), metadata={}))
     return result
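Spreading the whole source dict into metadata (instead of whitelisting three keys) is what lets later steps recover knowledge_id and document_id after reranking. A small sketch, with illustrative keys and an import path assumed to match the node's:

```python
from langchain_core.documents import Document  # assumed import path

document = {'title': 'Refunds', 'content': 'Refunds take 5 days.',
            'knowledge_id': 'k-1', 'knowledge_name': 'product-faq',
            'knowledge_type': 'base', 'document_id': 'd-1',
            'document_name': 'faq.md'}
doc = Document(page_content=document['content'],
               metadata={'title': document.get('title'), **document})
# knowledge_id/document_id survive reranking, which is what the
# reset_chat_record filtering later in this PR relies on.
assert 'knowledge_id' in doc.metadata and 'document_id' in doc.metadata
```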
@@ -71,17 +69,18 @@ def save_context(self, details, workflow_manage):
         self.context['result_list'] = details.get('result_list')
         self.context['result'] = details.get('result')
 
-    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
+    def execute(self, question, reranker_setting, reranker_list, reranker_model_id, show_knowledge,
                 **kwargs) -> NodeResult:
+        self.context['show_knowledge'] = show_knowledge
         documents = merge_reranker_list(reranker_list)
         top_n = reranker_setting.get('top_n', 3)
         self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for
                                          document in documents]
         self.context['question'] = question
         workspace_id = self.workflow_manage.get_body().get('workspace_id')
         reranker_model = get_model_instance_by_model_workspace_id(reranker_model_id,
-                                                                  workspace_id,
-                                                                  top_n=top_n)
+                                                                  workspace_id,
+                                                                  top_n=top_n)
         result = reranker_model.compress_documents(
             documents,
             question)
@@ -93,6 +92,7 @@ def execute(self, question, reranker_setting, reranker_list, reranker_model_id,

     def get_details(self, index: int, **kwargs):
         return {
+            'show_knowledge': self.context.get('show_knowledge'),
             'name': self.node.properties.get('stepName'),
             "index": index,
             'document_list': self.context.get('document_list'),
@@ -41,6 +41,9 @@ class SearchDatasetStepNodeSerializer(serializers.Serializer):

     question_reference_address = serializers.ListField(required=True)
 
+    show_knowledge = serializers.BooleanField(required=True,
+                                              label=_("The results are displayed in the knowledge sources"))
+
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
 
@@ -73,7 +76,7 @@ def _run(self):
         return self.execute(**self.node_params_serializer.data, question=str(question),
                             exclude_paragraph_id_list=exclude_paragraph_id_list)
 
-    def execute(self, dataset_id_list, dataset_setting, question,
+    def execute(self, dataset_id_list, dataset_setting, question, show_knowledge,
                 exclude_paragraph_id_list=None,
                 **kwargs) -> NodeResult:
         pass
@@ -62,10 +62,11 @@ def save_context(self, details, workflow_manage):
             result])[0:dataset_setting.get('max_paragraph_char_number', 5000)]
         self.context['directly_return'] = directly_return
 
-    def execute(self, knowledge_id_list, knowledge_setting, question,
+    def execute(self, knowledge_id_list, knowledge_setting, question, show_knowledge,
                 exclude_paragraph_id_list=None,
                 **kwargs) -> NodeResult:
         self.context['question'] = question
+        self.context['show_knowledge'] = show_knowledge
         get_knowledge_list_of_authorized = DatabaseModelManage.get_model('get_knowledge_list_of_authorized')
         chat_user_type = self.workflow_manage.get_body().get('chat_user_type')
         if get_knowledge_list_of_authorized is not None and RoleConstants.CHAT_USER.value.name == chat_user_type:
@@ -145,6 +146,7 @@ def list_paragraph(embedding_list: List, vector):
     def get_details(self, index: int, **kwargs):
         return {
             'name': self.node.properties.get('stepName'),
+            'show_knowledge': self.context.get('show_knowledge'),
             'question': self.context.get('question'),
             "index": index,
             'run_time': self.context.get('run_time'),
60 changes: 45 additions & 15 deletions apps/application/serializers/application_chat_record.py
@@ -75,7 +75,15 @@ def one(self, debug):
         chat_record = self.get_chat_record()
         if chat_record is None:
             raise AppApiException(500, gettext("Conversation does not exist"))
-        return ApplicationChatRecordQuerySerializers.reset_chat_record(chat_record)
+        application_access_token = QuerySet(ApplicationAccessToken).filter(
+            application_id=self.data.get('application_id')).first()
+        show_source = False
+        show_exec = False
+        if application_access_token is not None:
+            show_exec = application_access_token.show_exec
+            show_source = application_access_token.show_source
+        return ApplicationChatRecordQuerySerializers.reset_chat_record(
+            chat_record, show_source, show_exec)
 
 
 class ApplicationChatRecordQuerySerializers(serializers.Serializer):
@@ -103,21 +111,34 @@ def list(self, with_valid=True):
             QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by(order_by)]
 
     @staticmethod
-    def reset_chat_record(chat_record):
+    def reset_chat_record(chat_record, show_source, show_exec):
         knowledge_list = []
         paragraph_list = []
 
-        if 'search_step' in chat_record.details and chat_record.details.get('search_step').get(
-                'paragraph_list') is not None:
-            paragraph_list = chat_record.details.get('search_step').get(
-                'paragraph_list')
-            knowledge_list = [{'id': dataset_id, 'name': name} for dataset_id, name in
-                              reduce(lambda x, y: {**x, **y},
-                                     [{row.get('knowledge_id'): row.get("knowledge_name")} for row in paragraph_list],
-                                     {}).items()]
+        for item in chat_record.details.values():
+            if item.get('type') == 'search-knowledge-node' and item.get('show_knowledge', False):
+                paragraph_list = paragraph_list + item.get(
+                    'paragraph_list')
+
+            if item.get('type') == 'reranker-node' and item.get('show_knowledge', False):
+                paragraph_list = paragraph_list + [rl.get('metadata') for rl in item.get('result_list') if
+                                                   'document_id' in rl.get('metadata') and 'knowledge_id' in rl.get(
+                                                       'metadata')]
+        paragraph_list = list({p.get('id'): p for p in paragraph_list}.values())
+        knowledge_list = knowledge_list + [{'id': knowledge_id, **knowledge} for knowledge_id, knowledge in
+                                           reduce(lambda x, y: {**x, **y},
+                                                  [{row.get('knowledge_id'): {
+                                                      'knowledge_name': row.get("knowledge_name"),
+                                                      'knowledge_type': row.get('knowledge_type')}} for
+                                                   row in paragraph_list],
+                                                  {}).items()]
 
         if len(chat_record.improve_paragraph_id_list) > 0:
             paragraph_model_list = QuerySet(Paragraph).filter(id__in=chat_record.improve_paragraph_id_list)
             if len(paragraph_model_list) < len(chat_record.improve_paragraph_id_list):
@@ -126,24 +147,33 @@ def reset_chat_record(chat_record):
                 filter(lambda p_id: paragraph_model_id_list.__contains__(p_id),
                        chat_record.improve_paragraph_id_list))
             chat_record.save()
+
+        show_source_dict = {'knowledge_list': knowledge_list,
+                            'paragraph_list': paragraph_list, }
+        show_exec_dict = {'execution_details': [chat_record.details[key] for key in chat_record.details]}
         return {
             **ChatRecordSerializerModel(chat_record).data,
             'padding_problem_text': chat_record.details.get('problem_padding').get(
                 'padding_problem_text') if 'problem_padding' in chat_record.details else None,
-            'knowledge_list': knowledge_list,
-            'paragraph_list': paragraph_list,
-            'execution_details': [chat_record.details[key] for key in chat_record.details]
+            **(show_source_dict if show_source else {}),
+            **(show_exec_dict if show_exec else {})
         }
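For clarity, a sketch of what the reduce-based dedupe above computes — one entry per knowledge_id, with the last paragraph seen supplying the name and type (sample data is illustrative):

```python
from functools import reduce

paragraph_list = [
    {'id': 'p-1', 'knowledge_id': 'k-1', 'knowledge_name': 'FAQ', 'knowledge_type': 'base'},
    {'id': 'p-2', 'knowledge_id': 'k-1', 'knowledge_name': 'FAQ', 'knowledge_type': 'base'},
    {'id': 'p-3', 'knowledge_id': 'k-2', 'knowledge_name': 'Docs', 'knowledge_type': 'web'},
]
# Merge one single-key dict per paragraph; duplicate knowledge_ids collapse.
merged = reduce(lambda x, y: {**x, **y},
                [{row.get('knowledge_id'): {'knowledge_name': row.get('knowledge_name'),
                                            'knowledge_type': row.get('knowledge_type')}}
                 for row in paragraph_list],
                {})
knowledge_list = [{'id': knowledge_id, **knowledge}
                  for knowledge_id, knowledge in merged.items()]
# -> [{'id': 'k-1', 'knowledge_name': 'FAQ', 'knowledge_type': 'base'},
#     {'id': 'k-2', 'knowledge_name': 'Docs', 'knowledge_type': 'web'}]
```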

     def page(self, current_page: int, page_size: int, with_valid=True):
         if with_valid:
             self.is_valid(raise_exception=True)
         order_by = '-create_time' if self.data.get('order_asc') is None or self.data.get(
             'order_asc') else 'create_time'
+        application_access_token = QuerySet(ApplicationAccessToken).filter(
+            application_id=self.data.get('application_id')).first()
+        show_source = False
+        show_exec = False
+        if application_access_token is not None:
+            show_exec = application_access_token.show_exec
+            show_source = application_access_token.show_source
         page = page_search(current_page, page_size,
                            QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by(order_by),
-                           post_records_handler=lambda chat_record: self.reset_chat_record(chat_record))
+                           post_records_handler=lambda chat_record: self.reset_chat_record(chat_record, show_source,
+                                                                                           show_exec))
         return page
 
17 changes: 16 additions & 1 deletion apps/application/serializers/common.py
@@ -14,10 +14,12 @@
 from django.utils.translation import gettext_lazy as _
 
 from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
-from application.models import Application, ChatRecord, Chat, ApplicationVersion, ChatUserType
+from application.models import Application, ChatRecord, Chat, ApplicationVersion, ChatUserType, ApplicationTypeChoices, \
+    ApplicationKnowledgeMapping
 from common.constants.cache_version import Cache_Version
 from common.database_model_manage.database_model_manage import DatabaseModelManage
 from common.exception.app_exception import ChatException
+from knowledge.models import Document
 from models_provider.models import Model
 from models_provider.tools import get_model_credential
 
@@ -72,6 +74,19 @@ def get_application(self):
                 '-create_time')[0:1].first()
         if not application:
             raise ChatException(500, _("The application has not been published. Please use it after publishing."))
+        if application.type == ApplicationTypeChoices.SIMPLE.value:
+            # dataset ID list
+            knowledge_id_list = [str(row.knowledge_id) for row in
+                                 QuerySet(ApplicationKnowledgeMapping).filter(
+                                     application_id=self.application_id)]
+
+            # documents to exclude
+            exclude_document_id_list = [str(document.id) for document in
+                                        QuerySet(Document).filter(
+                                            knowledge_id__in=knowledge_id_list,
+                                            is_active=False)]
+            self.knowledge_id_list = knowledge_id_list
+            self.exclude_document_id_list = exclude_document_id_list
         self.application = application
         return application

@@ -1,6 +1,7 @@
 SELECT
     paragraph.*,
     knowledge."name" AS "knowledge_name",
+    knowledge."type" AS "knowledge_type",
     "document"."name" AS "document_name",
     "document"."meta" AS "meta",
     "document"."hit_handling_method" AS "hit_handling_method",