Skip to content

Commit 7dc132e

Browse files
committed
feat: implement text-to-video and image-to-video generation nodes with serializers and workflow integration
1 parent ce6f801 commit 7dc132e

File tree

38 files changed

+1912
-31
lines changed

38 files changed

+1912
-31
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .document_extract_node import *
1414
from .form_node import *
1515
from .image_generate_step_node import *
16+
from .image_to_video_step_node import BaseImageToVideoNode
1617
from .image_understand_step_node import *
1718
from .mcp_node import BaseMcpNode
1819
from .question_node import *
@@ -21,6 +22,7 @@
2122
from .speech_to_text_step_node import BaseSpeechToTextNode
2223
from .start_node import *
2324
from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode
25+
from .text_to_video_step_node.impl.base_text_to_video_node import BaseTextToVideoNode
2426
from .tool_lib_node import *
2527
from .tool_node import *
2628
from .variable_assign_node import BaseVariableAssignNode
@@ -31,7 +33,8 @@
3133
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
3234
BaseDocumentExtractNode,
3335
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
34-
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode,BaseIntentNode]
36+
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseTextToVideoNode, BaseImageToVideoNode,
37+
BaseIntentNode]
3538

3639

3740
def get_node(node_type):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .impl import *
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from django.utils.translation import gettext_lazy as _
6+
from rest_framework import serializers
7+
8+
from application.flow.i_step_node import INode, NodeResult
9+
10+
11+
class ImageToVideoNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on an image-to-video workflow node."""

    # --- model selection -------------------------------------------------
    model_id = serializers.CharField(label=_("Model id"), required=True)

    # --- prompts ---------------------------------------------------------
    prompt = serializers.CharField(label=_("Prompt word (positive)"), required=True)
    negative_prompt = serializers.CharField(
        label=_("Prompt word (negative)"),
        required=False,
        allow_blank=True,
        allow_null=True,
    )

    # --- conversation history --------------------------------------------
    # Number of multi-round conversation turns
    dialogue_number = serializers.IntegerField(
        label=_("Number of multi-round conversations"),
        required=False,
        default=0,
    )
    dialogue_type = serializers.CharField(
        label=_("Conversation storage type"),
        required=False,
        default='NODE',
    )

    # --- output handling -------------------------------------------------
    is_result = serializers.BooleanField(
        label=_('Whether to return content'),
        required=False,
    )

    # --- provider specific settings (defaults to an empty dict) ----------
    model_params_setting = serializers.JSONField(
        label=_("Model parameter settings"),
        required=False,
        default=dict,
    )

    # --- frame references ------------------------------------------------
    # Each value is a list; the node passes element 0 and the remaining
    # elements to the workflow manager to look up the referenced field.
    first_frame_url = serializers.ListField(label=_("First frame url"), required=True)
    last_frame_url = serializers.ListField(label=_("Last frame url"), required=False)
33+
34+
35+
class IImageToVideoNode(INode):
    """Interface of the image-to-video workflow node.

    ``_run`` resolves the first/last frame references configured on the node
    into concrete values and forwards them, together with the remaining node
    parameters and the flow parameters, to ``execute``.
    """
    type = 'image-to-video-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class that validates this node's parameters."""
        return ImageToVideoNodeSerializer

    def _resolve_frame_reference(self, reference):
        """Look up a ``[head, *path]`` reference list through the workflow manager."""
        return self.workflow_manage.get_reference_field(reference[0], reference[1:])

    def _run(self):
        """Resolve the frame references and delegate to ``execute``.

        Raises:
            ValueError: if the first frame reference is missing/empty or
                resolves to an empty value.
        """
        node_params = self.node_params_serializer.data

        first_frame_reference = node_params.get('first_frame_url')
        # An empty reference list used to blow up with a bare IndexError;
        # report it with the same message as an empty resolved value.
        if not first_frame_reference:
            raise ValueError(_("First frame url cannot be empty"))
        first_frame_url = self._resolve_frame_reference(first_frame_reference)
        # The previous check was ``first_frame_url is []``: an identity
        # comparison against a brand-new list, which is never true, so the
        # validation could not trigger. Truthiness covers None, [] and ''.
        if not first_frame_url:
            raise ValueError(_("First frame url cannot be empty"))

        # The last frame is optional: only resolve it when a reference is set.
        last_frame_url = None
        last_frame_reference = node_params.get('last_frame_url')
        if last_frame_reference:
            last_frame_url = self._resolve_frame_reference(last_frame_reference)

        # The raw reference lists must not reach execute(); the resolved
        # values are passed explicitly instead.
        node_params_data = {
            key: value for key, value in node_params.items()
            if key not in ('first_frame_url', 'last_frame_url')
        }
        return self.execute(first_frame_url=first_frame_url, last_frame_url=last_frame_url,
                            **node_params_data, **self.flow_params_serializer.data)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                first_frame_url, last_frame_url,
                **kwargs) -> NodeResult:
        """Generate the video; overridden by the concrete node implementation.

        This base version does nothing and returns ``None``.
        """
        pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .base_image_to_video_node import BaseImageToVideoNode
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# coding=utf-8
2+
import base64
3+
from functools import reduce
4+
from typing import List
5+
6+
import requests
7+
from django.db.models import QuerySet
8+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
9+
10+
from application.flow.i_step_node import NodeResult
11+
from application.flow.step_node.image_to_video_step_node.i_image_to_video_node import IImageToVideoNode
12+
from common.utils.common import bytes_to_uploaded_file
13+
from knowledge.models import FileSourceType, File
14+
from oss.serializers.file import FileSerializer, mime_types
15+
from models_provider.tools import get_model_instance_by_model_workspace_id
16+
17+
18+
class BaseImageToVideoNode(IImageToVideoNode):
    """Image-to-video workflow node.

    Asks the configured model to generate a video from a prompt plus a first
    (and optionally last) frame, uploads the result through ``FileSerializer``
    and returns an HTML ``<video>`` tag as the answer.
    """

    # (connect, read) timeouts in seconds used when downloading the generated
    # video; without a timeout a stalled server would block the workflow forever.
    VIDEO_DOWNLOAD_TIMEOUT = (10, 300)

    def save_context(self, details, workflow_manage):
        """Restore ``answer``/``question`` into the node context from stored details."""
        answer = details.get('answer')
        self.context['answer'] = answer
        self.context['question'] = details.get('question')
        if self.node_params.get('is_result', False):
            self.answer_text = answer

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                first_frame_url, last_frame_url=None,
                **kwargs) -> NodeResult:
        """Generate a video and store it as an application file.

        Returns:
            A ``NodeResult`` whose ``answer`` is a ``<video>`` tag pointing at
            the uploaded file, or a failure message when the model returned
            ``None``.

        Raises:
            ValueError: if a frame refers to a file id that does not exist.
            requests.HTTPError: if downloading the generated video fails.
        """
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        workspace_id = self.workflow_manage.get_body().get('workspace_id')
        ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
                                                             **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question
        message_list = self.generate_message_list(question, history_message)
        self.context['message_list'] = message_list
        self.context['dialogue_type'] = dialogue_type
        self.context['negative_prompt'] = negative_prompt
        # The context keeps the frames as received (before base64 conversion).
        self.context['first_frame_url'] = first_frame_url
        self.context['last_frame_url'] = last_frame_url

        # A frame can be a URL or a file id: URLs go to the model unchanged,
        # file ids are converted to a base64 data URL first.
        first_frame = self.get_file_base64(first_frame_url)
        last_frame = self.get_file_base64(last_frame_url)
        video = ttv_model.generate_video(question, negative_prompt, first_frame, last_frame)
        if video is None:
            return NodeResult({'answer': '生成视频失败'}, {})

        # Persist the generated video.
        file_name = 'generated_video.mp4'
        if isinstance(video, str) and video.startswith('http'):
            response = requests.get(video, timeout=self.VIDEO_DOWNLOAD_TIMEOUT)
            # Do not store an HTTP error page as if it were the video.
            response.raise_for_status()
            video = response.content
        file = bytes_to_uploaded_file(video, file_name)
        application_id = str(application.id) if application.id else None
        meta = {
            # No application id is treated as a debug run.
            'debug': not application.id,
            'chat_id': chat_id,
            'application_id': application_id,
        }
        file_url = FileSerializer(data={
            'file': file,
            'meta': meta,
            'source_id': application_id,
            'source_type': FileSourceType.APPLICATION.value
        }).upload()
        video_label = f'<video src="{file_url}" controls style="max-width: 100%; width: 100%; height: auto; max-height: 60vh;"></video>'
        video_list = [{'file_id': file_url.split('/')[-1], 'file_name': file_name, 'url': file_url}]
        return NodeResult({'answer': video_label, 'chat_model': ttv_model, 'message_list': message_list,
                           'video': video_list,
                           'history_message': history_message, 'question': question}, {})

    def get_file_base64(self, image_url):
        """Turn a frame value into something the model can consume.

        * a list: its first item's ``file_id`` is used (``None`` for an empty list);
        * a string not starting with ``http``: treated as a ``File`` id and
          returned as a ``data:<mime>;base64,...`` URL;
        * anything else (``http`` URLs, ``None``, ...): returned unchanged.

        Raises:
            ValueError: if no ``File`` with the given id exists.
        """
        if isinstance(image_url, list):
            if not image_url:
                # Nothing selected: behave like "no frame" instead of IndexError.
                return None
            image_url = image_url[0].get('file_id')
        if isinstance(image_url, str) and not image_url.startswith('http'):
            file = QuerySet(File).filter(id=image_url).first()
            if file is None:
                # first() returns None when nothing matches; fail clearly
                # rather than with an AttributeError on file.get_bytes().
                raise ValueError(f'Frame image file not found: {image_url}')
            file_bytes = file.get_bytes()
            # The content type is derived from the file extension; if it were
            # unknown, the magic library could be used to detect it.
            file_type = file.file_name.split(".")[-1].lower()
            content_type = mime_types.get(file_type, 'application/octet-stream')
            encoded_bytes = base64.b64encode(file_bytes)
            return f'data:{content_type};base64,{encoded_bytes.decode()}'
        return image_url

    def generate_history_ai_message(self, chat_record):
        """Build the AI-side history message of ``chat_record`` for this node.

        NOTE(review): ``get_details`` of this class does not emit an
        ``image_list`` key, so the node-specific branch may never match for
        records produced by this node - confirm whether ``video`` was intended.
        """
        for val in chat_record.details.values():
            if self.node.id == val['node_id'] and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    return chat_record.get_ai_message()
                image_list = val['image_list']
                return AIMessage(content=[
                    {'type': 'image_url', 'image_url': {'url': f'{file_url}'}} for file_url in image_list
                ])
        return chat_record.get_ai_message()

    def get_history_message(self, history_chat_record, dialogue_number):
        """Return human/AI message pairs for the last ``dialogue_number`` records."""
        start_index = max(len(history_chat_record) - dialogue_number, 0)
        history_message = []
        # A plain loop replaces reduce(lambda x, y: [*x, *y], ...), which
        # re-copied the accumulated list on every step.
        for index in range(start_index, len(history_chat_record)):
            chat_record = history_chat_record[index]
            history_message.append(self.generate_history_human_message(chat_record))
            history_message.append(self.generate_history_ai_message(chat_record))
        return history_message

    def generate_history_human_message(self, chat_record):
        """Build the human-side history message of ``chat_record`` for this node."""
        for data in chat_record.details.values():
            if self.node.id == data['node_id'] and 'image_list' in data:
                image_list = data['image_list']
                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                    return HumanMessage(content=chat_record.problem_text)
                return HumanMessage(content=data['question'])
        return HumanMessage(content=chat_record.problem_text)

    def generate_prompt_question(self, prompt):
        """Render ``prompt`` through the workflow manager's prompt generator."""
        return self.workflow_manage.generate_prompt(prompt)

    def generate_message_list(self, question: str, history_message):
        """Return the history messages followed by the current question."""
        return [
            *history_message,
            question
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages to ``{'role', 'content'}`` dicts and append the answer as an ``ai`` entry."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Return the execution details of this node, read mostly from the node context."""
        history_message = self.context.get('history_message') or []
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'history_message': [{'content': message.content, 'role': message.type}
                                for message in history_message],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'first_frame_url': self.context.get('first_frame_url'),
            'last_frame_url': self.context.get('last_frame_url'),
            'dialogue_type': self.context.get('dialogue_type'),
            'negative_prompt': self.context.get('negative_prompt'),
        }
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .impl import *
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from django.utils.translation import gettext_lazy as _
6+
from rest_framework import serializers
7+
8+
from application.flow.i_step_node import INode, NodeResult
9+
10+
11+
class TextToVideoNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a text-to-video workflow node."""

    # --- model selection -------------------------------------------------
    model_id = serializers.CharField(label=_("Model id"), required=True)

    # --- prompts ---------------------------------------------------------
    prompt = serializers.CharField(label=_("Prompt word (positive)"), required=True)
    negative_prompt = serializers.CharField(
        label=_("Prompt word (negative)"),
        required=False,
        allow_blank=True,
        allow_null=True,
    )

    # --- conversation history --------------------------------------------
    # Number of multi-round conversation turns
    dialogue_number = serializers.IntegerField(
        label=_("Number of multi-round conversations"),
        required=False,
        default=0,
    )
    dialogue_type = serializers.CharField(
        label=_("Conversation storage type"),
        required=False,
        default='NODE',
    )

    # --- output handling -------------------------------------------------
    is_result = serializers.BooleanField(
        label=_('Whether to return content'),
        required=False,
    )

    # --- provider specific settings (defaults to an empty dict) ----------
    model_params_setting = serializers.JSONField(
        label=_("Model parameter settings"),
        required=False,
        default=dict,
    )
31+
32+
class ITextToVideoNode(INode):
    """Interface of the text-to-video workflow node."""
    type = 'text-to-video-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class that validates this node's parameters."""
        return TextToVideoNodeSerializer

    def _run(self):
        """Call ``execute`` with the validated node parameters and flow parameters."""
        node_params = self.node_params_serializer.data
        flow_params = self.flow_params_serializer.data
        # Two separate ** expansions are kept on purpose: a key present in
        # both mappings raises TypeError instead of being silently overridden.
        return self.execute(**node_params, **flow_params)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                **kwargs) -> NodeResult:
        """Generate the video; overridden by the concrete node implementation.

        This base version does nothing and returns ``None``.
        """
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .base_text_to_video_node import BaseTextToVideoNode

0 commit comments

Comments
 (0)