Commit 622780e

feat: add video understanding node and related components
1 parent 80c790b commit 622780e

23 files changed: +922 / -71 lines

apps/application/flow/step_node/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -32,13 +32,15 @@
 from .tool_node import *
 from .variable_assign_node import BaseVariableAssignNode
 from .variable_splitting_node import BaseVariableSplittingNode
+from .video_understand_step_node import BaseVideoUnderstandNode

 node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseQuestionNode,
              BaseConditionNode, BaseReplyNode,
              BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
              BaseDocumentExtractNode,
              BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
              BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseTextToVideoNode, BaseImageToVideoNode,
+             BaseVideoUnderstandNode,
              BaseIntentNode, BaseLoopNode, BaseLoopStartStepNode,
              BaseLoopContinueNode,
              BaseLoopBreakNode, BaseVariableSplittingNode]
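
Note: the new node follows the existing registration pattern, being re-exported from its package and appended to node_list. A minimal sketch (not part of this commit) of how a registered node class can be resolved from node_list by its declared `type`, assuming every Base*Node subclass exposes that attribute the way IVideoUnderstandNode does with type = 'video-understand-node'; resolve_node_class is a hypothetical helper for illustration only:

from application.flow.step_node import node_list

def resolve_node_class(node_type: str):
    # Return the first registered node class whose declared type matches.
    for node_class in node_list:
        if getattr(node_class, 'type', None) == node_type:
            return node_class
    raise ValueError(f'unknown node type: {node_type}')

video_node_class = resolve_node_class('video-understand-node')  # expected: BaseVideoUnderstandNode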

apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py

Lines changed: 10 additions & 7 deletions
@@ -169,13 +169,16 @@ def generate_message_list(self, image_model, system: str, prompt: str, history_m
         # process multiple images
         images = []
         for img in image:
-            file_id = img['file_id']
-            file = QuerySet(File).filter(id=file_id).first()
-            image_bytes = file.get_bytes()
-            base64_image = base64.b64encode(image_bytes).decode("utf-8")
-            image_format = what(None, image_bytes)
-            images.append(
-                {'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
+            if isinstance(img, str) and img.startswith('http'):
+                images.append({'type': 'image_url', 'image_url': {'url': img}})
+            else:
+                file_id = img['file_id']
+                file = QuerySet(File).filter(id=file_id).first()
+                image_bytes = file.get_bytes()
+                base64_image = base64.b64encode(image_bytes).decode("utf-8")
+                image_format = what(None, image_bytes)
+                images.append(
+                    {'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
         messages = [HumanMessage(
             content=[
                 {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
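
Note: this change lets generate_message_list accept either a hosted URL (a string starting with http) or a stored file reference ({'file_id': ...}); URLs are passed through while stored files are still inlined as base64 data URLs. A simplified, standalone sketch of that branching, where load_file_bytes is a hypothetical stand-in for QuerySet(File).filter(id=file_id).first().get_bytes():

import base64
from imghdr import what

def to_image_part(img, load_file_bytes):
    if isinstance(img, str) and img.startswith('http'):
        # remote URL: pass it through untouched
        return {'type': 'image_url', 'image_url': {'url': img}}
    # stored file: inline it as a base64 data URL, guessing the format from the bytes
    image_bytes = load_file_bytes(img['file_id'])
    base64_image = base64.b64encode(image_bytes).decode('utf-8')
    image_format = what(None, image_bytes)
    return {'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}}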

apps/application/flow/step_node/start_node/impl/base_start_node.py

Lines changed: 3 additions & 0 deletions
@@ -48,6 +48,7 @@ def save_context(self, details, workflow_manage):
         self.context['document'] = details.get('document_list')
         self.context['image'] = details.get('image_list')
         self.context['audio'] = details.get('audio_list')
+        self.context['video'] = details.get('video_list')
         self.context['other'] = details.get('other_list')
         self.status = details.get('status')
         self.err_message = details.get('err_message')
@@ -73,6 +74,7 @@ def execute(self, question, **kwargs) -> NodeResult:
             'image': self.workflow_manage.image_list,
             'document': self.workflow_manage.document_list,
             'audio': self.workflow_manage.audio_list,
+            'video': self.workflow_manage.video_list,
             'other': self.workflow_manage.other_list,

         }
@@ -97,6 +99,7 @@ def get_details(self, index: int, **kwargs):
             'status': self.status,
             'err_message': self.err_message,
             'image_list': self.context.get('image'),
+            'video_list': self.context.get('video'),
             'document_list': self.context.get('document'),
             'audio_list': self.context.get('audio'),
             'other_list': self.context.get('other'),
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .impl import *
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# coding=utf-8
+
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+
+from django.utils.translation import gettext_lazy as _
+
+
+class VideoUnderstandNodeSerializer(serializers.Serializer):
+    model_id = serializers.CharField(required=True, label=_("Model id"))
+    system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+                                   label=_("Role Setting"))
+    prompt = serializers.CharField(required=True, label=_("Prompt word"))
+    # number of multi-turn dialogue rounds
+    dialogue_number = serializers.IntegerField(required=True, label=_("Number of multi-round conversations"))
+
+    dialogue_type = serializers.CharField(required=True, label=_("Conversation storage type"))
+
+    is_result = serializers.BooleanField(required=False,
+                                         label=_('Whether to return content'))
+
+    video_list = serializers.ListField(required=False, label=_("video"))
+
+    model_params_setting = serializers.JSONField(required=False, default=dict,
+                                                 label=_("Model parameter settings"))
+
+
+class IVideoUnderstandNode(INode):
+    type = 'video-understand-node'
+
+    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+        return VideoUnderstandNodeSerializer
+
+    def _run(self):
+        res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('video_list')[0],
+                                                       self.node_params_serializer.data.get('video_list')[1:])
+        return self.execute(video=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
+                model_params_setting,
+                chat_record_id,
+                video,
+                **kwargs) -> NodeResult:
+        pass
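
Note: given the fields above, a node-parameters payload that should validate against VideoUnderstandNodeSerializer looks roughly like the sketch below (run inside the Django app context; the model id is a placeholder and the video_list entries follow the node-reference form consumed by get_reference_field in _run, i.e. a node id followed by a field path):

from application.flow.step_node.video_understand_step_node.i_video_understand_node import VideoUnderstandNodeSerializer

node_params = {
    'model_id': '00000000-0000-0000-0000-000000000000',  # placeholder model id
    'system': 'You are a helpful video analyst.',
    'prompt': 'Describe what happens in the attached video.',
    'dialogue_number': 3,
    'dialogue_type': 'WORKFLOW',
    'is_result': True,
    'video_list': ['start-node', 'video'],  # hypothetical reference: node id + field path
    'model_params_setting': {},
}

serializer = VideoUnderstandNodeSerializer(data=node_params)
serializer.is_valid(raise_exception=True)
print(serializer.validated_data['dialogue_type'])  # 'WORKFLOW'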
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .base_video_understand_node import BaseVideoUnderstandNode
Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
+# coding=utf-8
+import base64
+import time
+from functools import reduce
+from imghdr import what
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
+
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.video_understand_step_node.i_video_understand_node import IVideoUnderstandNode
+from knowledge.models import File
+from models_provider.tools import get_model_instance_by_model_workspace_id
+
+
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+    chat_model = node_variable.get('chat_model')
+    message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0
+    answer_tokens = chat_model.get_num_tokens(answer)
+    node.context['message_tokens'] = message_tokens
+    node.context['answer_tokens'] = answer_tokens
+    node.context['answer'] = answer
+    node.context['history_message'] = node_variable['history_message']
+    node.context['question'] = node_variable['question']
+    node.context['run_time'] = time.time() - node.context['start_time']
+    if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
+        node.answer_text = answer
+
+
+def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+    """
+    Write context data (streaming)
+    @param node_variable:     node data
+    @param workflow_variable: global data
+    @param node:              node
+    @param workflow:          workflow manager
+    """
+    response = node_variable.get('result')
+    answer = ''
+    for chunk in response:
+        answer += chunk.content
+        yield chunk.content
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
+
+
+def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+    """
+    Write context data
+    @param node_variable:     node data
+    @param workflow_variable: global data
+    @param node:              node instance
+    @param workflow:          workflow manager
+    """
+    response = node_variable.get('result')
+    answer = response.content
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
+
+
+def file_id_to_base64(file_id: str):
+    file = QuerySet(File).filter(id=file_id).first()
+    file_bytes = file.get_bytes()
+    base64_video = base64.b64encode(file_bytes).decode("utf-8")
+    return [base64_video, what(None, file_bytes)]
+
+
+class BaseVideoUnderstandNode(IVideoUnderstandNode):
+    def save_context(self, details, workflow_manage):
+        self.context['answer'] = details.get('answer')
+        self.context['question'] = details.get('question')
+        if self.node_params.get('is_result', False):
+            self.answer_text = details.get('answer')
+
+    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
+                model_params_setting,
+                chat_record_id,
+                video,
+                **kwargs) -> NodeResult:
+        # handle invalid parameters
+        if video is None or not isinstance(video, list):
+            video = []
+        workspace_id = self.workflow_manage.get_body().get('workspace_id')
+        video_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
+                                                               **model_params_setting)
+        # history messages in the execution details do not need the video content
+        history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
+        self.context['history_message'] = history_message
+        question = self.generate_prompt_question(prompt)
+        self.context['question'] = question.content
+        # build the message list with the real history_message
+        message_list = self.generate_message_list(video_model, system, prompt,
+                                                  self.get_history_message(history_chat_record, dialogue_number), video)
+        self.context['message_list'] = message_list
+        self.context['video_list'] = video
+        self.context['dialogue_type'] = dialogue_type
+        if stream:
+            r = video_model.stream(message_list)
+            return NodeResult({'result': r, 'chat_model': video_model, 'message_list': message_list,
+                               'history_message': history_message, 'question': question.content}, {},
+                              _write_context=write_context_stream)
+        else:
+            r = video_model.invoke(message_list)
+            return NodeResult({'result': r, 'chat_model': video_model, 'message_list': message_list,
+                               'history_message': history_message, 'question': question.content}, {},
+                              _write_context=write_context)
+
+    def get_history_message_for_details(self, history_chat_record, dialogue_number):
+        start_index = len(history_chat_record) - dialogue_number
+        history_message = reduce(lambda x, y: [*x, *y], [
+            [self.generate_history_human_message_for_details(history_chat_record[index]),
+             self.generate_history_ai_message(history_chat_record[index])]
+            for index in
+            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+        return history_message
+
+    def generate_history_ai_message(self, chat_record):
+        for val in chat_record.details.values():
+            if self.node.id == val['node_id'] and 'video_list' in val:
+                if val['dialogue_type'] == 'WORKFLOW':
+                    return chat_record.get_ai_message()
+                return AIMessage(content=val['answer'])
+        return chat_record.get_ai_message()
+
+    def generate_history_human_message_for_details(self, chat_record):
+        for data in chat_record.details.values():
+            if self.node.id == data['node_id'] and 'video_list' in data:
+                video_list = data['video_list']
+                if len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
+                    return HumanMessage(content=chat_record.problem_text)
+                file_id_list = [video.get('file_id') for video in video_list]
+                return HumanMessage(content=[
+                    {'type': 'text', 'text': data['question']},
+                    *[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
+
+                ])
+        return HumanMessage(content=chat_record.problem_text)
+
+    def get_history_message(self, history_chat_record, dialogue_number):
+        start_index = len(history_chat_record) - dialogue_number
+        history_message = reduce(lambda x, y: [*x, *y], [
+            [self.generate_history_human_message(history_chat_record[index]),
+             self.generate_history_ai_message(history_chat_record[index])]
+            for index in
+            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+        return history_message
+
+    def generate_history_human_message(self, chat_record):
+
+        for data in chat_record.details.values():
+            if self.node.id == data['node_id'] and 'video_list' in data:
+                video_list = data['video_list']
+                if len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
+                    return HumanMessage(content=chat_record.problem_text)
+                video_base64_list = [file_id_to_base64(video.get('file_id')) for video in video_list]
+                return HumanMessage(
+                    content=[
+                        {'type': 'text', 'text': data['question']},
+                        *[{'type': 'video_url',
+                           'video_url': {'url': f'data:video/{base64_video[1]};base64,{base64_video[0]}'}} for
+                          base64_video in video_base64_list]
+                    ])
+        return HumanMessage(content=chat_record.problem_text)
+
+    def generate_prompt_question(self, prompt):
+        return HumanMessage(self.workflow_manage.generate_prompt(prompt))
+
+    def generate_message_list(self, video_model, system: str, prompt: str, history_message, video):
+        if video is not None and len(video) > 0:
+            # process multiple videos
+            videos = []
+            for img in video:
+                if isinstance(img, str) and img.startswith('http'):
+                    videos.append({'type': 'video_url', 'video_url': {'url': img}})
+                else:
+                    file_id = img['file_id']
+                    file = QuerySet(File).filter(id=file_id).first()
+                    video_bytes = file.get_bytes()
+                    base64_video = base64.b64encode(video_bytes).decode("utf-8")
+                    video_format = what(None, video_bytes)
+                    videos.append(
+                        {'type': 'video_url', 'video_url': {'url': f'data:video/{video_format};base64,{base64_video}'}})
+            messages = [HumanMessage(
+                content=[
+                    {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
+                    *videos
+                ])]
+        else:
+            messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+
+        if system is not None and len(system) > 0:
+            return [
+                SystemMessage(self.workflow_manage.generate_prompt(system)),
+                *history_message,
+                *messages
+            ]
+        else:
+            return [
+                *history_message,
+                *messages
+            ]
+
+    @staticmethod
+    def reset_message_list(message_list: List[BaseMessage], answer_text):
+        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+                  message
+                  in
+                  message_list]
+        result.append({'role': 'ai', 'content': answer_text})
+        return result
+
+    def get_details(self, index: int, **kwargs):
+        return {
+            'name': self.node.properties.get('stepName'),
+            "index": index,
+            'run_time': self.context.get('run_time'),
+            'system': self.node_params.get('system'),
+            'history_message': [{'content': message.content, 'role': message.type} for message in
+                                (self.context.get('history_message') if self.context.get(
+                                    'history_message') is not None else [])],
+            'question': self.context.get('question'),
+            'answer': self.context.get('answer'),
+            'type': self.node.type,
+            'message_tokens': self.context.get('message_tokens'),
+            'answer_tokens': self.context.get('answer_tokens'),
+            'status': self.status,
+            'err_message': self.err_message,
+            'video_list': self.context.get('video_list'),
+            'dialogue_type': self.context.get('dialogue_type')
+        }
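
Note: generate_message_list mirrors the image-understand node: http(s) URLs are forwarded as video_url parts, stored files are inlined as data:video/...;base64 URLs, and an optional SystemMessage plus the reconstructed history are prepended. An illustrative sketch of the list shape for a single http video and a non-empty system prompt (all literal values are placeholders):

from langchain_core.messages import HumanMessage, SystemMessage

expected_message_list = [
    SystemMessage('You are a helpful video analyst.'),  # only present when system is non-empty
    # *history_message entries would be spliced in here
    HumanMessage(content=[
        {'type': 'text', 'text': 'Describe what happens in this clip.'},
        {'type': 'video_url', 'video_url': {'url': 'https://example.com/clip.mp4'}},
    ]),
]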

apps/application/flow/workflow_manage.py

Lines changed: 4 additions & 0 deletions
@@ -94,6 +94,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
                  base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
                  document_list=None,
                  audio_list=None,
+                 video_list=None,
                  other_list=None,
                  start_node_id=None,
                  start_node_data=None, chat_record=None, child_node=None):
@@ -105,12 +106,15 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
             document_list = []
         if audio_list is None:
             audio_list = []
+        if video_list is None:
+            video_list = []
         if other_list is None:
             other_list = []
         self.start_node_id = start_node_id
         self.start_node = None
         self.form_data = form_data
         self.image_list = image_list
+        self.video_list = video_list
         self.document_list = document_list
         self.audio_list = audio_list
         self.other_list = other_list

apps/chat/serializers/chat.py

Lines changed: 2 additions & 0 deletions
@@ -375,6 +375,7 @@ def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
         chat_user_type = self.data.get('chat_user_type')
         form_data = instance.get('form_data')
         image_list = instance.get('image_list')
+        video_list = instance.get('video_list')
         document_list = instance.get('document_list')
         audio_list = instance.get('audio_list')
         other_list = instance.get('other_list')
@@ -401,6 +402,7 @@ def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
                                        'application_id': str(chat_info.application_id)},
                                       WorkFlowPostHandler(chat_info),
                                       base_to_response, form_data, image_list, document_list, audio_list,
+                                      video_list,
                                       other_list,
                                       instance.get('runtime_node_id'),
                                       instance.get('node_data'), chat_record, instance.get('child_node'))
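
Note: on the API side, chat_work_flow now also reads video_list from the request instance and threads it through to the workflow. A hypothetical instance fragment that exercises the new field, using only the keys visible in this diff (the file_id is a placeholder; the base node also accepts plain http URLs in place of file references):

instance = {
    'form_data': {},
    'image_list': [],
    'video_list': [{'file_id': '00000000-0000-0000-0000-000000000000'}],  # new in this commit
    'document_list': [],
    'audio_list': [],
    'other_list': [],
}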
