|
6 | 6 | @date:2025/6/9 11:23 |
7 | 7 | @desc: |
8 | 8 | """ |
9 | | - |
| 9 | +import json |
10 | 10 | from gettext import gettext |
11 | 11 | from typing import List, Dict |
12 | 12 |
|
13 | 13 | import uuid_utils.compat as uuid |
14 | 14 | from django.db.models import QuerySet |
15 | 15 | from django.utils.translation import gettext_lazy as _ |
| 16 | +from langchain_core.messages import HumanMessage, AIMessage |
16 | 17 | from rest_framework import serializers |
17 | 18 |
|
18 | 19 | from application.chat_pipeline.pipeline_manage import PipelineManage |
|
24 | 25 | from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep |
25 | 26 | from application.flow.common import Answer, Workflow |
26 | 27 | from application.flow.i_step_node import WorkFlowPostHandler |
| 28 | +from application.flow.tools import to_stream_response_simple |
27 | 29 | from application.flow.workflow_manage import WorkflowManage |
28 | 30 | from application.models import Application, ApplicationTypeChoices, ApplicationKnowledgeMapping, \ |
29 | 31 | ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat, ApplicationVersion |
|
37 | 39 | from common.utils.common import flat_map |
38 | 40 | from knowledge.models import Document, Paragraph |
39 | 41 | from models_provider.models import Model, Status |
| 42 | +from models_provider.tools import get_model_instance_by_model_workspace_id |
| 43 | + |
| 44 | + |
| 45 | +class ChatMessagesSerializers(serializers.Serializer): |
| 46 | + role = serializers.CharField(required=True, label=_("Role")) |
| 47 | + content = serializers.CharField(required=True, label=_("Content")) |
| 48 | + |
| 49 | + |
| 50 | +class GeneratePromptSerializers(serializers.Serializer): |
| 51 | + prompt = serializers.CharField(required=True, label=_("Prompt template")) |
| 52 | + messages = serializers.ListSerializer(child=ChatMessagesSerializers(), required=True, label=_("Chat context")) |
40 | 53 |
|
| 54 | + def is_valid(self, *, raise_exception=False): |
| 55 | + super().is_valid(raise_exception=True) |
| 56 | + messages = self.data.get("messages") |
| 57 | + |
| 58 | + if len(messages) > 30: |
| 59 | + raise AppApiException(400, _("Too many messages")) |
| 60 | + |
| 61 | + for index in range(len(messages)): |
| 62 | + role = messages[index].get('role') |
| 63 | + if role == 'ai' and index % 2 != 1: |
| 64 | + raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) |
| 65 | + if role == 'user' and index % 2 != 0: |
| 66 | + raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) |
| 67 | + if role not in ['user', 'ai']: |
| 68 | + raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) |
41 | 69 |
|
42 | 70 | class ChatMessageSerializers(serializers.Serializer): |
43 | 71 | message = serializers.CharField(required=True, label=_("User Questions")) |
@@ -113,6 +141,37 @@ def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToRespon |
113 | 141 | }).chat(instance, base_to_response) |
114 | 142 |
|
115 | 143 |
|
| 144 | +class PromptGenerateSerializer(serializers.Serializer): |
| 145 | + workspace_id = serializers.CharField(required=False, label=_('Workspace ID')) |
| 146 | + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model")) |
| 147 | + |
| 148 | + def generate_prompt(self, instance: dict, with_valid=True): |
| 149 | + if with_valid: |
| 150 | + self.is_valid(raise_exception=True) |
| 151 | + GeneratePromptSerializers(data=instance).is_valid(raise_exception=True) |
| 152 | + workspace_id = self.data.get('workspace_id') |
| 153 | + model_id = self.data.get('model_id') |
| 154 | + prompt = instance.get('prompt') |
| 155 | + messages = instance.get('messages') |
| 156 | + |
| 157 | + message = messages[-1]['content'] |
| 158 | + q = prompt.replace("{userInput}", message) |
| 159 | + messages[-1]['content'] = q |
| 160 | + |
| 161 | + model_exist = QuerySet(Model).filter(workspace_id=workspace_id, id=model_id).exists() |
| 162 | + if not model_exist: |
| 163 | + raise Exception(_("model does not exists")) |
| 164 | + |
| 165 | + def process(): |
| 166 | + model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id) |
| 167 | + |
| 168 | + for r in model.stream([HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage( |
| 169 | + content=m.get('content')) for m in messages]): |
| 170 | + yield 'data: ' + json.dumps({'content': r.content}) + '\n\n' |
| 171 | + |
| 172 | + return to_stream_response_simple(process()) |
| 173 | + |
| 174 | + |
116 | 175 | class OpenAIMessage(serializers.Serializer): |
117 | 176 | content = serializers.CharField(required=True, label=_('content')) |
118 | 177 | role = serializers.CharField(required=True, label=_('Role')) |
|
0 commit comments