Skip to content

Commit cd66d72

Browse files
authored
convert CLIPTextEncodeSDXL nodes to V3 schema (#9716)
1 parent 2103e39 commit cd66d72

File tree

1 file changed

+57
-40
lines changed

1 file changed

+57
-40
lines changed

comfy_extras/nodes_clip_sdxl.py

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,52 @@
1-
from nodes import MAX_RESOLUTION
1+
from typing_extensions import override
22

3-
class CLIPTextEncodeSDXLRefiner:
3+
import nodes
4+
from comfy_api.latest import ComfyExtension, io
5+
6+
7+
class CLIPTextEncodeSDXLRefiner(io.ComfyNode):
48
@classmethod
5-
def INPUT_TYPES(s):
6-
return {"required": {
7-
"ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
8-
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
9-
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
10-
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
11-
}}
12-
RETURN_TYPES = ("CONDITIONING",)
13-
FUNCTION = "encode"
14-
15-
CATEGORY = "advanced/conditioning"
16-
17-
def encode(self, clip, ascore, width, height, text):
9+
def define_schema(cls):
10+
return io.Schema(
11+
node_id="CLIPTextEncodeSDXLRefiner",
12+
category="advanced/conditioning",
13+
inputs=[
14+
io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01),
15+
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
16+
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
17+
io.String.Input("text", multiline=True, dynamic_prompts=True),
18+
io.Clip.Input("clip"),
19+
],
20+
outputs=[io.Conditioning.Output()],
21+
)
22+
23+
@classmethod
24+
def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput:
1825
tokens = clip.tokenize(text)
19-
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
26+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}))
2027

21-
class CLIPTextEncodeSDXL:
28+
class CLIPTextEncodeSDXL(io.ComfyNode):
2229
@classmethod
23-
def INPUT_TYPES(s):
24-
return {"required": {
25-
"clip": ("CLIP", ),
26-
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
27-
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
28-
"crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
29-
"crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
30-
"target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
31-
"target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
32-
"text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
33-
"text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
34-
}}
35-
RETURN_TYPES = ("CONDITIONING",)
36-
FUNCTION = "encode"
37-
38-
CATEGORY = "advanced/conditioning"
39-
40-
def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
30+
def define_schema(cls):
31+
return io.Schema(
32+
node_id="CLIPTextEncodeSDXL",
33+
category="advanced/conditioning",
34+
inputs=[
35+
io.Clip.Input("clip"),
36+
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
37+
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
38+
io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION),
39+
io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION),
40+
io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
41+
io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
42+
io.String.Input("text_g", multiline=True, dynamic_prompts=True),
43+
io.String.Input("text_l", multiline=True, dynamic_prompts=True),
44+
],
45+
outputs=[io.Conditioning.Output()],
46+
)
47+
48+
@classmethod
49+
def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput:
4150
tokens = clip.tokenize(text_g)
4251
tokens["l"] = clip.tokenize(text_l)["l"]
4352
if len(tokens["l"]) != len(tokens["g"]):
@@ -46,9 +55,17 @@ def encode(self, clip, width, height, crop_w, crop_h, target_width, target_heigh
4655
tokens["l"] += empty["l"]
4756
while len(tokens["l"]) > len(tokens["g"]):
4857
tokens["g"] += empty["g"]
49-
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
58+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}))
59+
60+
61+
class ClipSdxlExtension(ComfyExtension):
62+
@override
63+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
64+
return [
65+
CLIPTextEncodeSDXLRefiner,
66+
CLIPTextEncodeSDXL,
67+
]
68+
5069

51-
NODE_CLASS_MAPPINGS = {
52-
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
53-
"CLIPTextEncodeSDXL": CLIPTextEncodeSDXL,
54-
}
70+
async def comfy_entrypoint() -> ClipSdxlExtension:
71+
return ClipSdxlExtension()

0 commit comments

Comments
 (0)