Skip to content

Commit ead6c5c

Browse files
committed
feat: Volcanic Engine Image Model
1 parent d6abad2 commit ead6c5c

File tree

5 files changed

+137
-8
lines changed

5 files changed

+137
-8
lines changed

apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
1313
from dataset.models import File
1414
from setting.models_provider.tools import get_model_instance_by_model_user_id
15+
from imghdr import what
1516

1617

1718
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@@ -59,8 +60,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
5960

6061
def file_id_to_base64(file_id: str):
6162
file = QuerySet(File).filter(id=file_id).first()
62-
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
63-
return base64_image
63+
file_bytes = file.get_byte()
64+
base64_image = base64.b64encode(file_bytes).decode("utf-8")
65+
return [base64_image, what(None, file_bytes.tobytes())]
6466

6567

6668
class BaseImageUnderstandNode(IImageUnderstandNode):
@@ -77,7 +79,7 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
7779
# 处理不正确的参数
7880
if image is None or not isinstance(image, list):
7981
image = []
80-
82+
print(model_params_setting)
8183
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
8284
# 执行详情中的历史消息不需要图片内容
8385
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
@@ -152,7 +154,7 @@ def generate_history_human_message(self, chat_record):
152154
return HumanMessage(
153155
content=[
154156
{'type': 'text', 'text': data['question']},
155-
*[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}} for
157+
*[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
156158
base64_image in image_base64_list]
157159
])
158160
return HumanMessage(content=chat_record.problem_text)
@@ -167,8 +169,10 @@ def generate_message_list(self, image_model, system: str, prompt: str, history_m
167169
for img in image:
168170
file_id = img['file_id']
169171
file = QuerySet(File).filter(id=file_id).first()
170-
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
171-
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}})
172+
image_bytes = file.get_byte()
173+
base64_image = base64.b64encode(image_bytes).decode("utf-8")
174+
image_format = what(None, image_bytes.tobytes())
175+
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
172176
messages = [HumanMessage(
173177
content=[
174178
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
class VolcanicEngineImageModelParams(BaseForm):
14+
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
15+
required=True, default_value=0.95,
16+
_min=0.1,
17+
_max=1.0,
18+
_step=0.01,
19+
precision=2)
20+
21+
max_tokens = forms.SliderField(
22+
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
23+
required=True, default_value=1024,
24+
_min=1,
25+
_max=100000,
26+
_step=1,
27+
precision=0)
28+
29+
class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
30+
api_key = forms.PasswordInputField('API Key', required=True)
31+
api_base = forms.TextInputField('API 域名', required=True)
32+
33+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
34+
raise_exception=False):
35+
model_type_list = provider.get_model_type_list()
36+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
37+
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
38+
39+
for key in ['api_key', 'api_base']:
40+
if key not in model_credential:
41+
if raise_exception:
42+
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
43+
else:
44+
return False
45+
try:
46+
model = provider.get_model(model_type, model_name, model_credential)
47+
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
48+
for chunk in res:
49+
print(chunk)
50+
except Exception as e:
51+
if isinstance(e, AppApiException):
52+
raise e
53+
if raise_exception:
54+
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
55+
else:
56+
return False
57+
return True
58+
59+
def encryption_dict(self, model: Dict[str, object]):
60+
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
61+
62+
def get_model_params_setting_form(self, model_name):
63+
return VolcanicEngineImageModelParams()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Dict
2+
3+
from langchain_openai.chat_models import ChatOpenAI
4+
5+
from common.config.tokenizer_manage_config import TokenizerManage
6+
from setting.models_provider.base_model_provider import MaxKBBaseModel
7+
8+
9+
def custom_get_token_ids(text: str):
10+
tokenizer = TokenizerManage.get_tokenizer()
11+
return tokenizer.encode(text)
12+
13+
14+
class VolcanicEngineImage(MaxKBBaseModel, ChatOpenAI):
15+
16+
@staticmethod
17+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
18+
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
19+
return VolcanicEngineImage(
20+
model_name=model_name,
21+
openai_api_key=model_credential.get('api_key'),
22+
openai_api_base=model_credential.get('api_base'),
23+
# stream_options={"include_usage": True},
24+
streaming=True,
25+
**optional_params,
26+
)

apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
1515
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
1616
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
17+
from setting.models_provider.impl.volcanic_engine_model_provider.credential.image import \
18+
VolcanicEngineImageModelCredential
1719
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
20+
from setting.models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage
1821
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
1922
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
2023
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
@@ -25,13 +28,19 @@
2528
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
2629
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
2730
volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential()
31+
volcanic_engine_image_model_credential = VolcanicEngineImageModelCredential()
2832

2933
model_info_list = [
3034
ModelInfo('ep-xxxxxxxxxx-yyyy',
3135
'用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
3236
ModelTypeConst.LLM,
3337
volcanic_engine_llm_model_credential, VolcanicEngineChatModel
3438
),
39+
ModelInfo('ep-xxxxxxxxxx-yyyy',
40+
'用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
41+
ModelTypeConst.IMAGE,
42+
volcanic_engine_image_model_credential, VolcanicEngineImage
43+
),
3544
ModelInfo('asr',
3645
'',
3746
ModelTypeConst.STT,
@@ -51,8 +60,13 @@
5160
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
5261
OpenAIEmbeddingModel)]
5362

54-
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
55-
model_info_list[0]).build()
63+
model_info_manage = (
64+
ModelInfoManage.builder()
65+
.append_model_info_list(model_info_list)
66+
.append_default_model_info(model_info_list[0])
67+
.append_default_model_info(model_info_list[1])
68+
.build()
69+
)
5670

5771

5872
class VolcanicEngineModelProvider(IModelProvider):

ui/src/workflow/nodes/image-understand/index.vue

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@
2525
<div>
2626
<span>图片理解模型<span class="danger">*</span></span>
2727
</div>
28+
<el-button
29+
:disabled="!form_data.model_id"
30+
type="primary"
31+
link
32+
@click="openAIParamSettingDialog(form_data.model_id)"
33+
@refreshForm="refreshParam"
34+
>
35+
{{ $t('views.application.applicationForm.form.paramSetting') }}
36+
</el-button>
2837
</div>
2938
</template>
3039
<el-select
@@ -183,6 +192,7 @@
183192
</el-form-item>
184193
</el-form>
185194
</el-card>
195+
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
186196
</NodeContainer>
187197
</template>
188198

@@ -197,6 +207,7 @@ import { app } from '@/main'
197207
import useStore from '@/stores'
198208
import NodeCascader from '@/workflow/common/NodeCascader.vue'
199209
import type { FormInstance } from 'element-plus'
210+
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
200211
201212
const { model } = useStore()
202213
@@ -207,6 +218,7 @@ const {
207218
const props = defineProps<{ nodeModel: any }>()
208219
const modelOptions = ref<any>(null)
209220
const providerOptions = ref<Array<Provider>>([])
221+
const AIModeParamSettingDialogRef = ref<InstanceType<typeof AIModeParamSettingDialog>>()
210222
211223
const aiChatNodeFormRef = ref<FormInstance>()
212224
const validate = () => {
@@ -281,6 +293,16 @@ function submitDialog(val: string) {
281293
set(props.nodeModel.properties.node_data, 'prompt', val)
282294
}
283295
296+
const openAIParamSettingDialog = (modelId: string) => {
297+
if (modelId) {
298+
AIModeParamSettingDialogRef.value?.open(modelId, id, form_data.value.model_params_setting)
299+
}
300+
}
301+
302+
function refreshParam(data: any) {
303+
set(props.nodeModel.properties.node_data, 'model_params_setting', data)
304+
}
305+
284306
onMounted(() => {
285307
getModel()
286308
getProvider()

0 commit comments

Comments
 (0)