Skip to content

Commit 7eca956

Browse files
authored
convert nodes_photomaker.py to V3 schema (#10017)
1 parent ad5aef2 commit 7eca956

File tree

1 file changed

+46
-28
lines changed

1 file changed

+46
-28
lines changed

comfy_extras/nodes_photomaker.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import comfy.clip_model
55
import comfy.clip_vision
66
import comfy.ops
7+
from typing_extensions import override
8+
from comfy_api.latest import ComfyExtension, io
79

810
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
911
VISION_CONFIG_DICT = {
@@ -116,41 +118,52 @@ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
116118
return updated_prompt_embeds
117119

118120

119-
class PhotoMakerLoader:
121+
class PhotoMakerLoader(io.ComfyNode):
120122
@classmethod
121-
def INPUT_TYPES(s):
122-
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
123-
124-
RETURN_TYPES = ("PHOTOMAKER",)
125-
FUNCTION = "load_photomaker_model"
126-
127-
CATEGORY = "_for_testing/photomaker"
123+
def define_schema(cls):
124+
return io.Schema(
125+
node_id="PhotoMakerLoader",
126+
category="_for_testing/photomaker",
127+
inputs=[
128+
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
129+
],
130+
outputs=[
131+
io.Photomaker.Output(),
132+
],
133+
is_experimental=True,
134+
)
128135

129-
def load_photomaker_model(self, photomaker_model_name):
136+
@classmethod
137+
def execute(cls, photomaker_model_name):
130138
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
131139
photomaker_model = PhotoMakerIDEncoder()
132140
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
133141
if "id_encoder" in data:
134142
data = data["id_encoder"]
135143
photomaker_model.load_state_dict(data)
136-
return (photomaker_model,)
144+
return io.NodeOutput(photomaker_model)
137145

138146

139-
class PhotoMakerEncode:
147+
class PhotoMakerEncode(io.ComfyNode):
140148
@classmethod
141-
def INPUT_TYPES(s):
142-
return {"required": { "photomaker": ("PHOTOMAKER",),
143-
"image": ("IMAGE",),
144-
"clip": ("CLIP", ),
145-
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}),
146-
}}
147-
148-
RETURN_TYPES = ("CONDITIONING",)
149-
FUNCTION = "apply_photomaker"
150-
151-
CATEGORY = "_for_testing/photomaker"
149+
def define_schema(cls):
150+
return io.Schema(
151+
node_id="PhotoMakerEncode",
152+
category="_for_testing/photomaker",
153+
inputs=[
154+
io.Photomaker.Input("photomaker"),
155+
io.Image.Input("image"),
156+
io.Clip.Input("clip"),
157+
io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"),
158+
],
159+
outputs=[
160+
io.Conditioning.Output(),
161+
],
162+
is_experimental=True,
163+
)
152164

153-
def apply_photomaker(self, photomaker, image, clip, text):
165+
@classmethod
166+
def execute(cls, photomaker, image, clip, text):
154167
special_token = "photomaker"
155168
pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
156169
try:
@@ -178,11 +191,16 @@ def apply_photomaker(self, photomaker, image, clip, text):
178191
else:
179192
out = cond
180193

181-
return ([[out, {"pooled_output": pooled}]], )
194+
return io.NodeOutput([[out, {"pooled_output": pooled}]])
182195

183196

184-
NODE_CLASS_MAPPINGS = {
185-
"PhotoMakerLoader": PhotoMakerLoader,
186-
"PhotoMakerEncode": PhotoMakerEncode,
187-
}
197+
class PhotomakerExtension(ComfyExtension):
198+
@override
199+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
200+
return [
201+
PhotoMakerLoader,
202+
PhotoMakerEncode,
203+
]
188204

205+
async def comfy_entrypoint() -> PhotomakerExtension:
206+
return PhotomakerExtension()

0 commit comments

Comments
 (0)