@@ -7,13 +7,14 @@
     @desc:
 """
 import json
+import os
 from gettext import gettext
 from typing import List, Dict
 
 import uuid_utils.compat as uuid
 from django.db.models import QuerySet
 from django.utils.translation import gettext_lazy as _
-from langchain_core.messages import HumanMessage, AIMessage
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from rest_framework import serializers
 
 from application.chat_pipeline.pipeline_manage import PipelineManage
@@ -36,8 +37,9 @@
 from common.handle.base_to_response import BaseToResponse
 from common.handle.impl.response.openai_to_response import OpenaiToResponse
 from common.handle.impl.response.system_to_response import SystemToResponse
-from common.utils.common import flat_map
+from common.utils.common import flat_map, get_file_content
 from knowledge.models import Document, Paragraph
+from maxkb.conf import PROJECT_DIR
 from models_provider.models import Model, Status
 from models_provider.tools import get_model_instance_by_model_workspace_id
 
@@ -67,6 +69,7 @@ def is_valid(self, *, raise_exception=False):
             if role not in ['user', 'ai']:
                 raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))
 
+
 class ChatMessageSerializers(serializers.Serializer):
     message = serializers.CharField(required=True, label=_("User Questions"))
     stream = serializers.BooleanField(required=True,
@@ -140,6 +143,7 @@ def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToRespon
             "application_id": chat_info.application.id, "debug": True
         }).chat(instance, base_to_response)
 
+SYSTEM_ROLE = get_file_content(os.path.join(PROJECT_DIR, "apps", "chat", 'template', 'generate_prompt_system'))
 
 class PromptGenerateSerializer(serializers.Serializer):
     workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
@@ -152,13 +156,14 @@ def is_valid(self, *, raise_exception=False):
         query_set = QuerySet(Application).filter(id=self.data.get('application_id'))
         if workspace_id:
             query_set = query_set.filter(workspace_id=workspace_id)
-        if not query_set.exists():
+        application = query_set.first()
+        if application is None:
             raise AppApiException(500, _('Application id does not exist'))
+        return application
 
-    def generate_prompt(self, instance: dict, with_valid=True):
-        if with_valid:
-            self.is_valid(raise_exception=True)
-            GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
+    def generate_prompt(self, instance: dict):
+        application = self.is_valid(raise_exception=True)
+        GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
         workspace_id = self.data.get('workspace_id')
         model_id = self.data.get('model_id')
         prompt = instance.get('prompt')
@@ -169,17 +174,19 @@ def generate_prompt(self, instance: dict, with_valid=True):
             messages[-1]['content'] = q
 
         model_exist = QuerySet(Model).filter(
-            id=model_id,
-            model_type="LLM"
-        ).exists()
+            id=model_id,
+            model_type="LLM"
+        ).exists()
         if not model_exist:
             raise Exception(_("Model does not exists or is not an LLM model"))
 
-        def process():
-            model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id)
+        system_content = SYSTEM_ROLE.format(application_name=application.name, detail=application.desc)
 
-            for r in model.stream([HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(
-                content=m.get('content')) for m in messages]):
+        def process():
+            model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id, **application.model_params_setting)
+            for r in model.stream([SystemMessage(content=system_content),
+                                   *[HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(
+                                       content=m.get('content')) for m in messages]]):
                 yield 'data: ' + json.dumps({'content': r.content}) + '\n\n'
 
         return to_stream_response_simple(process())
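
Note: the process() generator frames each model chunk as a Server-Sent Events message (a "data: <json>" line followed by a blank line), and to_stream_response_simple streams those frames back to the caller. Below is a minimal client-side sketch of consuming that stream; the endpoint URL and payload field names are illustrative assumptions, not taken from this diff, and it assumes the requests package is available.

    import json

    import requests

    # Hypothetical endpoint and request body; adjust to the real route, auth scheme and serializer fields.
    url = "http://localhost:8080/api/application/generate_prompt"
    payload = {"prompt": "Refine this prompt", "messages": [{"role": "user", "content": "hello"}]}

    with requests.post(url, json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            # Each SSE frame looks like: data: {"content": "..."}; blank lines separate frames.
            if line and line.startswith("data: "):
                chunk = json.loads(line[len("data: "):])
                print(chunk.get("content", ""), end="", flush=True)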
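
The generate_prompt_system template file referenced by SYSTEM_ROLE is not part of this diff; all the change implies is that get_file_content returns its text and that the text carries {application_name} and {detail} placeholders consumed by str.format. A hypothetical stand-in illustrating that contract (the wording is invented; only the placeholder names come from the code above):

    # Hypothetical template text; the real content lives in apps/chat/template/generate_prompt_system.
    SYSTEM_ROLE = (
        "You are a prompt-engineering assistant for the application {application_name}.\n"
        "Application description: {detail}\n"
        "Rewrite the user's draft into a clear, well-structured prompt."
    )

    # Mirrors the call in generate_prompt(); the name and description here are made-up sample values.
    system_content = SYSTEM_ROLE.format(application_name="Helpdesk Bot", detail="Answers IT support questions")
    print(system_content)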