Commit e141fe8

convert nodes_hunyuan.py to V3 schema
1 parent c4a8cf6 commit e141fe8

comfy_extras/nodes_hunyuan.py

Lines changed: 139 additions & 95 deletions
@@ -2,42 +2,54 @@
 import node_helpers
 import torch
 import comfy.model_management
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io


-class CLIPTextEncodeHunyuanDiT:
+class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "clip": ("CLIP", ),
-            "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            }}
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "encode"
-
-    CATEGORY = "advanced/conditioning"
-
-    def encode(self, clip, bert, mt5xl):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeHunyuanDiT",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("bert", multiline=True, dynamic_prompts=True),
+                io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
         tokens = clip.tokenize(bert)
         tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]

-        return (clip.encode_from_tokens_scheduled(tokens), )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))

-class EmptyHunyuanLatentVideo:
+class EmptyHunyuanLatentVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "generate"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyHunyuanLatentVideo",
+            category="latent/video",
+            inputs=[
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )

-    CATEGORY = "latent/video"
-
-    def generate(self, width, height, length, batch_size=1):
+    @classmethod
+    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        return ({"samples":latent}, )
+        return io.NodeOutput({"samples":latent})

 PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
     "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
@@ -50,45 +62,58 @@ def generate(self, width, height, length, batch_size=1):
     "<|start_header_id|>assistant<|end_header_id|>\n\n"
 )

-class TextEncodeHunyuanVideo_ImageToVideo:
+class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "clip": ("CLIP", ),
-            "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-            "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
-            }}
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "encode"
-
-    CATEGORY = "advanced/conditioning"
-
-    def encode(self, clip, clip_vision_output, prompt, image_interleave):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeHunyuanVideo_ImageToVideo",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.ClipVisionOutput.Input("clip_vision_output"),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Int.Input(
+                    "image_interleave",
+                    default=2,
+                    min=1,
+                    max=512,
+                    tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
+                ),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
         tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
-        return (clip.encode_from_tokens_scheduled(tokens), )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+class HunyuanImageToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HunyuanImageToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )

-class HunyuanImageToVideo:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                             "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
-                             },
-                "optional": {"start_image": ("IMAGE", ),
-                             }}
-
-    RETURN_TYPES = ("CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
+    def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         out_latent = {}

@@ -111,51 +136,70 @@ def encode(self, positive, vae, width, height, length, batch_size, guidance_type
         positive = node_helpers.conditioning_set_values(positive, cond)

         out_latent["samples"] = latent
-        return (positive, out_latent)
+        return io.NodeOutput(positive, out_latent)

-class EmptyHunyuanImageLatent:
+class EmptyHunyuanImageLatent(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "generate"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyHunyuanImageLatent",
+            category="latent",
+            inputs=[
+                io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )

-    CATEGORY = "latent"
-
-    def generate(self, width, height, batch_size=1):
+    @classmethod
+    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
-        return ({"samples":latent}, )
+        return io.NodeOutput({"samples":latent})

-class HunyuanRefinerLatent:
+class HunyuanRefinerLatent(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "latent": ("LATENT", ),
-                             "noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
-                             }}
-
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-
-    FUNCTION = "execute"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HunyuanRefinerLatent",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Latent.Input("latent"),
+                io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01),
+
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )

-    def execute(self, positive, negative, latent, noise_augmentation):
+    @classmethod
+    def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
         latent = latent["samples"]
         positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
         negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
         out_latent = {}
         out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
+
+
+class HunyuanExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            CLIPTextEncodeHunyuanDiT,
+            TextEncodeHunyuanVideo_ImageToVideo,
+            EmptyHunyuanLatentVideo,
+            HunyuanImageToVideo,
+            EmptyHunyuanImageLatent,
+            HunyuanRefinerLatent,
+        ]


-NODE_CLASS_MAPPINGS = {
-    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
-    "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
-    "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
-    "HunyuanImageToVideo": HunyuanImageToVideo,
-    "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
-    "HunyuanRefinerLatent": HunyuanRefinerLatent,
-}
+async def comfy_entrypoint() -> HunyuanExtension:
+    return HunyuanExtension()
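
The conversion pattern is the same for every node in this file: the V1 class attributes (INPUT_TYPES, RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY) collapse into a single io.Schema returned from define_schema, execute becomes a classmethod that returns io.NodeOutput, and the module-level NODE_CLASS_MAPPINGS dict is replaced by a ComfyExtension plus a comfy_entrypoint coroutine. The sketch below distills that shape from the diff; ExampleLatentNode and ExampleExtension are hypothetical names, and it assumes the ComfyUI runtime that provides comfy_api.latest.

# Minimal V3-schema sketch distilled from the diff above.
# Hypothetical node/extension names; assumes ComfyUI's comfy_api.latest.
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class ExampleLatentNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        # io.Schema replaces INPUT_TYPES/RETURN_TYPES/FUNCTION/CATEGORY.
        return io.Schema(
            node_id="ExampleLatentNode",
            category="latent",
            inputs=[
                io.Int.Input("width", default=512, min=16, max=8192, step=16),
                io.Int.Input("height", default=512, min=16, max=8192, step=16),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, width, height) -> io.NodeOutput:
        # Results are wrapped in io.NodeOutput instead of a bare tuple.
        latent = torch.zeros([1, 4, height // 8, width // 8])
        return io.NodeOutput({"samples": latent})


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # The extension's node list replaces the NODE_CLASS_MAPPINGS dict.
        return [ExampleLatentNode]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()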
