|
8 | 8 | """ |
9 | 9 |
|
10 | 10 | from gettext import gettext |
11 | | -from typing import List |
| 11 | +from typing import List, Dict |
12 | 12 |
|
13 | 13 | import uuid_utils.compat as uuid |
14 | 14 | from django.db.models import QuerySet |
|
31 | 31 | from application.serializers.common import ChatInfo |
32 | 32 | from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed, ChatException |
33 | 33 | from common.handle.base_to_response import BaseToResponse |
| 34 | +from common.handle.impl.response.openai_to_response import OpenaiToResponse |
34 | 35 | from common.handle.impl.response.system_to_response import SystemToResponse |
35 | 36 | from common.utils.common import flat_map |
36 | 37 | from knowledge.models import Document, Paragraph |
@@ -111,6 +112,66 @@ def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToRespon |
111 | 112 | }).chat(instance, base_to_response) |
112 | 113 |
|
113 | 114 |
|
| 115 | +class OpenAIMessage(serializers.Serializer): |
| 116 | + content = serializers.CharField(required=True, label=_('content')) |
| 117 | + role = serializers.CharField(required=True, label=_('Role')) |
| 118 | + |
| 119 | + |
| 120 | +class OpenAIInstanceSerializer(serializers.Serializer): |
| 121 | + messages = serializers.ListField(child=OpenAIMessage()) |
| 122 | + chat_id = serializers.UUIDField(required=False, label=_("Conversation ID")) |
| 123 | + re_chat = serializers.BooleanField(required=False, label=_("Regenerate")) |
| 124 | + stream = serializers.BooleanField(required=False, label=_("Streaming Output")) |
| 125 | + |
| 126 | + |
| 127 | +class OpenAIChatSerializer(serializers.Serializer): |
| 128 | + application_id = serializers.UUIDField(required=True, label=_("Application ID")) |
| 129 | + chat_user_id = serializers.CharField(required=True, label=_("Client id")) |
| 130 | + chat_user_type = serializers.CharField(required=True, label=_("Client Type")) |
| 131 | + |
| 132 | + @staticmethod |
| 133 | + def get_message(instance): |
| 134 | + return instance.get('messages')[-1].get('content') |
| 135 | + |
| 136 | + @staticmethod |
| 137 | + def generate_chat(chat_id, application_id, message, chat_user_id, chat_user_type): |
| 138 | + if chat_id is None: |
| 139 | + chat_id = str(uuid.uuid1()) |
| 140 | + chat_info = ChatInfo(chat_id, chat_user_id, chat_user_type, [], [], |
| 141 | + application_id) |
| 142 | + chat_info.set_cache() |
| 143 | + return chat_id |
| 144 | + |
| 145 | + def chat(self, instance: Dict, with_valid=True): |
| 146 | + if with_valid: |
| 147 | + self.is_valid(raise_exception=True) |
| 148 | + OpenAIInstanceSerializer(data=instance).is_valid(raise_exception=True) |
| 149 | + chat_id = instance.get('chat_id') |
| 150 | + message = self.get_message(instance) |
| 151 | + re_chat = instance.get('re_chat', False) |
| 152 | + stream = instance.get('stream', False) |
| 153 | + application_id = self.data.get('application_id') |
| 154 | + chat_user_id = self.data.get('chat_user_id') |
| 155 | + chat_user_type = self.data.get('chat_user_type') |
| 156 | + chat_id = self.generate_chat(chat_id, application_id, message, chat_user_id, chat_user_type) |
| 157 | + return ChatSerializers( |
| 158 | + data={ |
| 159 | + 'chat_id': chat_id, |
| 160 | + 'chat_user_id': chat_user_id, |
| 161 | + 'chat_user_type': chat_user_type, |
| 162 | + 'application_id': application_id |
| 163 | + } |
| 164 | + ).chat({'message': message, |
| 165 | + 're_chat': re_chat, |
| 166 | + 'stream': stream, |
| 167 | + 'form_data': instance.get('form_data', {}), |
| 168 | + 'image_list': instance.get('image_list', []), |
| 169 | + 'document_list': instance.get('document_list', []), |
| 170 | + 'audio_list': instance.get('audio_list', []), |
| 171 | + 'other_list': instance.get('other_list', [])}, |
| 172 | + base_to_response=OpenaiToResponse()) |
| 173 | + |
| 174 | + |
114 | 175 | class ChatSerializers(serializers.Serializer): |
115 | 176 | chat_id = serializers.UUIDField(required=True, label=_("Conversation ID")) |
116 | 177 | chat_user_id = serializers.CharField(required=True, label=_("Client id")) |
|
0 commit comments