Skip to content

Commit f3bd8b5

Browse files
committed
convert nodes_controlnet.py to V3 schema
1 parent bbd6830 commit f3bd8b5

File tree

1 file changed

+79
-36
lines changed

1 file changed

+79
-36
lines changed

comfy_extras/nodes_controlnet.py

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,103 @@
11
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
2-
import nodes
32
import comfy.utils
3+
from typing_extensions import override
4+
from comfy_api.latest import ComfyExtension, io
45

5-
class SetUnionControlNetType:
6+
class SetUnionControlNetType(io.ComfyNode):
67
@classmethod
7-
def INPUT_TYPES(s):
8-
return {"required": {"control_net": ("CONTROL_NET", ),
9-
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
10-
}}
8+
def define_schema(cls):
9+
return io.Schema(
10+
node_id="SetUnionControlNetType",
11+
category="conditioning/controlnet",
12+
inputs=[
13+
io.ControlNet.Input("control_net"),
14+
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
15+
],
16+
outputs=[
17+
io.ControlNet.Output(),
18+
],
19+
)
1120

12-
CATEGORY = "conditioning/controlnet"
13-
RETURN_TYPES = ("CONTROL_NET",)
14-
15-
FUNCTION = "set_controlnet_type"
16-
17-
def set_controlnet_type(self, control_net, type):
21+
@classmethod
22+
def execute(cls, control_net, type) -> io.NodeOutput:
1823
control_net = control_net.copy()
1924
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
2025
if type_number >= 0:
2126
control_net.set_extra_arg("control_type", [type_number])
2227
else:
2328
control_net.set_extra_arg("control_type", [])
2429

25-
return (control_net,)
30+
return io.NodeOutput(control_net)
31+
32+
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
33+
@classmethod
34+
def define_schema(cls):
35+
return io.Schema(
36+
node_id="ControlNetInpaintingAliMamaApply",
37+
category="conditioning/controlnet",
38+
inputs=[
39+
io.Conditioning.Input("positive"),
40+
io.Conditioning.Input("negative"),
41+
io.ControlNet.Input("control_net"),
42+
io.Vae.Input("vae"),
43+
io.Image.Input("image"),
44+
io.Mask.Input("mask"),
45+
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
46+
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
47+
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
48+
],
49+
outputs=[
50+
io.Conditioning.Output(display_name="positive"),
51+
io.Conditioning.Output(display_name="negative"),
52+
],
53+
)
2654

27-
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
2855
@classmethod
29-
def INPUT_TYPES(s):
30-
return {"required": {"positive": ("CONDITIONING", ),
31-
"negative": ("CONDITIONING", ),
32-
"control_net": ("CONTROL_NET", ),
33-
"vae": ("VAE", ),
34-
"image": ("IMAGE", ),
35-
"mask": ("MASK", ),
36-
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
37-
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
38-
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
39-
}}
40-
41-
FUNCTION = "apply_inpaint_controlnet"
42-
43-
CATEGORY = "conditioning/controlnet"
44-
45-
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
56+
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
4657
extra_concat = []
4758
if control_net.concat_mask:
4859
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
4960
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
5061
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
5162
extra_concat = [mask]
5263

53-
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
64+
if strength == 0:
65+
return io.NodeOutput(positive, negative)
66+
67+
control_hint = image.movedim(-1, 1)
68+
cnets = {}
69+
70+
out = []
71+
for conditioning in [positive, negative]:
72+
c = []
73+
for t in conditioning:
74+
d = t[1].copy()
75+
76+
prev_cnet = d.get('control', None)
77+
if prev_cnet in cnets:
78+
c_net = cnets[prev_cnet]
79+
else:
80+
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
81+
vae=vae, extra_concat=extra_concat)
82+
c_net.set_previous_controlnet(prev_cnet)
83+
cnets[prev_cnet] = c_net
84+
85+
d['control'] = c_net
86+
d['control_apply_to_uncond'] = False
87+
n = [t[0], d]
88+
c.append(n)
89+
out.append(c)
90+
return io.NodeOutput(out[0], out[1])
91+
5492

93+
class ControlNetExtension(ComfyExtension):
94+
@override
95+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
96+
return [
97+
SetUnionControlNetType,
98+
ControlNetInpaintingAliMamaApply,
99+
]
55100

56101

57-
NODE_CLASS_MAPPINGS = {
58-
"SetUnionControlNetType": SetUnionControlNetType,
59-
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
60-
}
102+
async def comfy_entrypoint() -> ControlNetExtension:
103+
return ControlNetExtension()

0 commit comments

Comments
 (0)