Skip to content

Commit 455e530

Browse files
committed
feat: 支持添加图片生成模型(WIP)
1 parent f65546a commit 455e530

File tree

15 files changed

+460
-13
lines changed

15 files changed

+460
-13
lines changed

apps/common/util/common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
"""
99
import hashlib
1010
import importlib
11+
import mimetypes
12+
import io
1113
from functools import reduce
1214
from typing import Dict, List
1315

16+
from django.core.files.uploadedfile import InMemoryUploadedFile
1417
from django.db.models import QuerySet
1518

1619
from ..exception.app_exception import AppApiException
@@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000):
111114
batch = data[i:i + batch_size]
112115
model.objects.bulk_create(batch)
113116

117+
118+
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
    """Wrap raw bytes as a Django ``InMemoryUploadedFile``.

    The content type is guessed from *file_name*; unrecognized extensions
    fall back to the generic binary type.
    """
    guessed_type, _ = mimetypes.guess_type(file_name)
    if guessed_type is None:
        # Unknown extension: default to a generic binary content type.
        guessed_type = "application/octet-stream"

    # In-memory byte stream backing the uploaded-file object.
    stream = io.BytesIO(file_bytes)

    return InMemoryUploadedFile(
        file=stream,
        field_name=None,
        name=file_name,
        content_type=guessed_type,
        size=len(file_bytes),
        charset=None,
    )

apps/setting/models_provider/base_model_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class ModelTypeConst(Enum):
150150
STT = {'code': 'STT', 'message': '语音识别'}
151151
TTS = {'code': 'TTS', 'message': '语音合成'}
152152
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
153+
TTI = {'code': 'TTI', 'message': '图片生成'}
153154
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
154155

155156

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
from abc import abstractmethod
3+
4+
from pydantic import BaseModel
5+
6+
7+
class BaseTextToImage(BaseModel):
    # Abstract interface for text-to-image (TTI) model implementations.
    # NOTE(review): pydantic's BaseModel does not use ABCMeta, so
    # @abstractmethod is not enforced at instantiation time here —
    # confirm every subclass actually overrides both methods.
    @abstractmethod
    def check_auth(self):
        """Validate the configured credentials against the provider."""
        pass

    @abstractmethod
    def generate_image(self, prompt: str):
        """Generate image(s) from *prompt*."""
        pass
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
14+
class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI text-to-image models."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate *model_credential* for the given model type/name.

        Checks that the provider supports *model_type*, that the required
        keys are present, and that the credentials actually authenticate.
        Returns True on success; returns False or raises AppApiException
        (depending on *raise_exception*) on failure.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # A successful check_auth() is all we need; its return value is
            # unused (removed leftover debug print of the result).
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with its api_key replaced by the encrypted form."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # WIP: no adjustable generation parameters are exposed yet.
        pass
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Dict
2+
3+
import requests
4+
from langchain_core.messages import HumanMessage
5+
from langchain_openai import ChatOpenAI
6+
from openai import OpenAI
7+
8+
from common.config.tokenizer_manage_config import TokenizerManage
9+
from common.util.common import bytes_to_uploaded_file
10+
from dataset.serializers.file_serializers import FileSerializer
11+
from setting.models_provider.base_model_provider import MaxKBBaseModel
12+
from setting.models_provider.impl.base_tti import BaseTextToImage
13+
14+
15+
def custom_get_token_ids(text: str):
16+
tokenizer = TokenizerManage.get_tokenizer()
17+
return tokenizer.encode(text)
18+
19+
20+
class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
    """OpenAI (DALL·E) text-to-image model wrapper."""
    api_base: str
    api_key: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials plus extra model kwargs."""
        optional_params = {'params': {}}
        for key, value in model_kwargs.items():
            # Framework bookkeeping keys are not generation parameters.
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return OpenAITextToImage(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Verify the credentials by listing models; raises on auth failure."""
        client = OpenAI(api_key=self.api_key, base_url=self.api_base)
        client.models.with_raw_response.list()

    def generate_image(self, prompt: str):
        """Generate image(s) for *prompt*, upload them, and return their URLs."""
        client = OpenAI(api_key=self.api_key, base_url=self.api_base)
        # Bug fix: use the configured model instead of hard-coding 'dall-e-3',
        # which silently ignored the model selected by the user (e.g. dall-e-2).
        res = client.images.generate(model=self.model, prompt=prompt)

        file_urls = []
        for content in res.data:
            file_name = 'generated_image.png'
            file = bytes_to_uploaded_file(requests.get(content.url).content, file_name)
            meta = {'debug': True}
            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
            file_urls.append(file_url)
        # Bug fix: return the uploaded URLs (the original built the list,
        # printed the raw response, and implicitly returned None).
        return file_urls

apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,21 @@
1515
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
1616
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
1717
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
18+
from setting.models_provider.impl.openai_model_provider.credential.tti import OpenAITextToImageModelCredential
1819
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
1920
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
2021
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
2122
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
2223
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
24+
from setting.models_provider.impl.openai_model_provider.model.tti import OpenAITextToImage
2325
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
2426
from smartdoc.conf import PROJECT_DIR
2527

2628
openai_llm_model_credential = OpenAILLMModelCredential()
2729
openai_stt_model_credential = OpenAISTTModelCredential()
2830
openai_tts_model_credential = OpenAITTSModelCredential()
2931
openai_image_model_credential = OpenAIImageModelCredential()
32+
openai_tti_model_credential = OpenAITextToImageModelCredential()
3033
model_info_list = [
3134
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
3235
openai_llm_model_credential, OpenAIChatModel
@@ -37,8 +40,8 @@
3740
ModelTypeConst.LLM, openai_llm_model_credential,
3841
OpenAIChatModel),
3942
ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',
40-
ModelTypeConst.LLM, openai_llm_model_credential,
41-
OpenAIChatModel),
43+
ModelTypeConst.LLM, openai_llm_model_credential,
44+
OpenAIChatModel),
4245
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
4346
openai_llm_model_credential,
4447
OpenAIChatModel),
@@ -100,11 +103,27 @@
100103
OpenAIImage),
101104
]
102105

103-
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
104-
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
105-
openai_llm_model_credential, OpenAIChatModel
106-
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
107-
model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()
106+
model_info_tti_list = [
107+
ModelInfo('dall-e-2', '',
108+
ModelTypeConst.TTI, openai_tti_model_credential,
109+
OpenAITextToImage),
110+
ModelInfo('dall-e-3', '',
111+
ModelTypeConst.TTI, openai_tti_model_credential,
112+
OpenAITextToImage),
113+
]
114+
115+
model_info_manage = (
116+
ModelInfoManage.builder()
117+
.append_model_info_list(model_info_list)
118+
.append_default_model_info(ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
119+
openai_llm_model_credential, OpenAIChatModel
120+
))
121+
.append_model_info_list(model_info_embedding_list)
122+
.append_default_model_info(model_info_embedding_list[0])
123+
.append_model_info_list(model_info_image_list)
124+
.append_model_info_list(model_info_tti_list)
125+
.build()
126+
)
108127

109128

110129
class OpenAIModelProvider(IModelProvider):
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: llm.py
6+
@date:2024/7/11 18:41
7+
@desc:
8+
"""
9+
import base64
10+
import os
11+
from typing import Dict
12+
13+
from langchain_core.messages import HumanMessage
14+
15+
from common import forms
16+
from common.exception.app_exception import AppApiException
17+
from common.forms import BaseForm, TooltipLabel
18+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
19+
20+
21+
class QwenModelParams(BaseForm):
    # Sampling temperature: higher values make output more random,
    # lower values make it more focused and deterministic.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=1.0,
                                    _min=0.1,
                                    _max=1.9,
                                    _step=0.01,
                                    precision=2)

    # Upper bound on the number of tokens the model may generate.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
36+
37+
38+
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Tongyi Qwen text-to-image models."""

    # Moved from the middle of the class to the top: form fields belong
    # before methods, consistent with the other credential classes.
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate *model_credential* for the given model type/name.

        Returns True on success; returns False or raises AppApiException
        (depending on *raise_exception*) on failure.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # A successful check_auth() is all we need; its return value is
            # unused (removed leftover debug print of the result).
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with its api_key replaced by the encrypted form."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the adjustable-parameter form shown for this model."""
        return QwenModelParams()
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# coding=utf-8
2+
from http import HTTPStatus
3+
from pathlib import PurePosixPath
4+
from typing import Dict
5+
from urllib.parse import unquote, urlparse
6+
7+
import requests
8+
from dashscope import ImageSynthesis
9+
from langchain_community.chat_models import ChatTongyi
10+
from langchain_core.messages import HumanMessage
11+
12+
from common.util.common import bytes_to_uploaded_file
13+
from dataset.serializers.file_serializers import FileSerializer
14+
from setting.models_provider.base_model_provider import MaxKBBaseModel
15+
from setting.models_provider.impl.base_tti import BaseTextToImage
16+
17+
18+
class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
    """Tongyi Qwen (DashScope) text-to-image model wrapper."""
    api_key: str
    model_name: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model_name = kwargs.get('model_name')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials plus extra model kwargs."""
        # Framework bookkeeping keys are not generation parameters.
        extra = {key: value for key, value in model_kwargs.items()
                 if key not in ['model_id', 'use_local', 'streaming']}
        return QwenTextToImageModel(
            model_name=model_name,
            api_key=model_credential.get('api_key'),
            params=extra,
        )

    def check_auth(self):
        """Smoke-test the API key with a minimal chat call; raises on failure."""
        chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
        chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])

    def generate_image(self, prompt: str):
        """Synthesize image(s) for *prompt*, upload them, return their URLs."""
        # api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
        rsp = ImageSynthesis.call(api_key=self.api_key,
                                  model=self.model_name,
                                  prompt=prompt,
                                  n=1,
                                  style='<watercolor>',
                                  size='1024*1024')
        file_urls = []
        if rsp.status_code != HTTPStatus.OK:
            # Best-effort: report the failure and return what we have (nothing).
            print('sync_call Failed, status_code: %s, code: %s, message: %s' %
                  (rsp.status_code, rsp.code, rsp.message))
            return file_urls
        for result in rsp.output.results:
            # Reuse the remote file's basename as the upload name.
            remote_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1]
            uploaded = bytes_to_uploaded_file(requests.get(result.url).content, remote_name)
            uploaded_url = FileSerializer(data={'file': uploaded, 'meta': {'debug': True}}).upload()
            file_urls.append(uploaded_url)
        return file_urls
66+
67+

0 commit comments

Comments
 (0)