Skip to content

Commit d0b4966

Browse files
committed
convert nodes_controlnet.py to V3 schema
1 parent bbd6830 commit d0b4966

File tree

1 file changed

+54
-35
lines changed

1 file changed

+54
-35
lines changed

comfy_extras/nodes_controlnet.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,79 @@
11
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
22
import nodes
33
import comfy.utils
4+
from typing_extensions import override
5+
from comfy_api.latest import ComfyExtension, io
46

5-
class SetUnionControlNetType:
7+
class SetUnionControlNetType(io.ComfyNode):
68
@classmethod
7-
def INPUT_TYPES(s):
8-
return {"required": {"control_net": ("CONTROL_NET", ),
9-
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
10-
}}
9+
def define_schema(cls):
10+
return io.Schema(
11+
node_id="SetUnionControlNetType",
12+
category="conditioning/controlnet",
13+
inputs=[
14+
io.ControlNet.Input("control_net"),
15+
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
16+
],
17+
outputs=[
18+
io.ControlNet.Output(),
19+
],
20+
)
1121

12-
CATEGORY = "conditioning/controlnet"
13-
RETURN_TYPES = ("CONTROL_NET",)
14-
15-
FUNCTION = "set_controlnet_type"
16-
17-
def set_controlnet_type(self, control_net, type):
22+
@classmethod
23+
def execute(cls, control_net, type) -> io.NodeOutput:
1824
control_net = control_net.copy()
1925
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
2026
if type_number >= 0:
2127
control_net.set_extra_arg("control_type", [type_number])
2228
else:
2329
control_net.set_extra_arg("control_type", [])
2430

25-
return (control_net,)
31+
return io.NodeOutput(control_net)
2632

27-
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
33+
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
2834
@classmethod
29-
def INPUT_TYPES(s):
30-
return {"required": {"positive": ("CONDITIONING", ),
31-
"negative": ("CONDITIONING", ),
32-
"control_net": ("CONTROL_NET", ),
33-
"vae": ("VAE", ),
34-
"image": ("IMAGE", ),
35-
"mask": ("MASK", ),
36-
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
37-
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
38-
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
39-
}}
40-
41-
FUNCTION = "apply_inpaint_controlnet"
42-
43-
CATEGORY = "conditioning/controlnet"
44-
45-
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
35+
def define_schema(cls):
36+
return io.Schema(
37+
node_id="ControlNetInpaintingAliMamaApply",
38+
category="conditioning/controlnet",
39+
inputs=[
40+
io.Conditioning.Input("positive"),
41+
io.Conditioning.Input("negative"),
42+
io.ControlNet.Input("control_net"),
43+
io.Vae.Input("vae"),
44+
io.Image.Input("image"),
45+
io.Mask.Input("mask"),
46+
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
47+
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
48+
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
49+
],
50+
outputs=[
51+
io.Conditioning.Output(display_name="positive"),
52+
io.Conditioning.Output(display_name="negative"),
53+
],
54+
)
55+
56+
@classmethod
57+
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
4658
extra_concat = []
4759
if control_net.concat_mask:
4860
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
4961
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
5062
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
5163
extra_concat = [mask]
5264

53-
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
65+
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
66+
return io.NodeOutput(result[0], result[1])
67+
5468

69+
class ControlNetExtension(ComfyExtension):
70+
@override
71+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
72+
return [
73+
SetUnionControlNetType,
74+
ControlNetInpaintingAliMamaApply,
75+
]
5576

5677

57-
NODE_CLASS_MAPPINGS = {
58-
"SetUnionControlNetType": SetUnionControlNetType,
59-
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
60-
}
78+
async def comfy_entrypoint() -> ControlNetExtension:
79+
return ControlNetExtension()

0 commit comments

Comments
 (0)