Skip to content

Commit a9cf1cd

Browse files
authored
convert nodes_hidream.py to V3 schema (#9946)
1 parent 2555721 commit a9cf1cd

File tree

1 file changed

+53
-35
lines changed

1 file changed

+53
-35
lines changed

comfy_extras/nodes_hidream.py

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,73 @@
1+
from typing_extensions import override
2+
13
import folder_paths
24
import comfy.sd
35
import comfy.model_management
6+
from comfy_api.latest import ComfyExtension, io
47

58

6-
class QuadrupleCLIPLoader:
9+
class QuadrupleCLIPLoader(io.ComfyNode):
710
@classmethod
8-
def INPUT_TYPES(s):
9-
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
10-
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
11-
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
12-
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
13-
}}
14-
RETURN_TYPES = ("CLIP",)
15-
FUNCTION = "load_clip"
16-
17-
CATEGORY = "advanced/loaders"
18-
19-
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
11+
def define_schema(cls):
12+
return io.Schema(
13+
node_id="QuadrupleCLIPLoader",
14+
category="advanced/loaders",
15+
description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
16+
inputs=[
17+
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
18+
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
19+
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
20+
io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
21+
],
22+
outputs=[
23+
io.Clip.Output(),
24+
]
25+
)
2026

21-
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
27+
@classmethod
28+
def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
2229
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
2330
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
2431
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
2532
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
2633
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
27-
return (clip,)
34+
return io.NodeOutput(clip)
2835

29-
class CLIPTextEncodeHiDream:
36+
class CLIPTextEncodeHiDream(io.ComfyNode):
3037
@classmethod
31-
def INPUT_TYPES(s):
32-
return {"required": {
33-
"clip": ("CLIP", ),
34-
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
35-
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
36-
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
37-
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
38-
}}
39-
RETURN_TYPES = ("CONDITIONING",)
40-
FUNCTION = "encode"
41-
42-
CATEGORY = "advanced/conditioning"
43-
44-
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
38+
def define_schema(cls):
39+
return io.Schema(
40+
node_id="CLIPTextEncodeHiDream",
41+
category="advanced/conditioning",
42+
inputs=[
43+
io.Clip.Input("clip"),
44+
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
45+
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
46+
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
47+
io.String.Input("llama", multiline=True, dynamic_prompts=True),
48+
],
49+
outputs=[
50+
io.Conditioning.Output(),
51+
]
52+
)
4553

54+
@classmethod
55+
def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
4656
tokens = clip.tokenize(clip_g)
4757
tokens["l"] = clip.tokenize(clip_l)["l"]
4858
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
4959
tokens["llama"] = clip.tokenize(llama)["llama"]
50-
return (clip.encode_from_tokens_scheduled(tokens), )
60+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
61+
62+
63+
class HiDreamExtension(ComfyExtension):
64+
@override
65+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
66+
return [
67+
QuadrupleCLIPLoader,
68+
CLIPTextEncodeHiDream,
69+
]
70+
5171

52-
NODE_CLASS_MAPPINGS = {
53-
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
54-
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
55-
}
72+
async def comfy_entrypoint() -> HiDreamExtension:
73+
return HiDreamExtension()

0 commit comments

Comments
 (0)