Skip to content

Commit 8f4ee99

Browse files
authored
convert nodes_morphology.py to V3 schema (#10159)
1 parent 0e9d172 commit 8f4ee99

File tree

1 file changed

+70
-46
lines changed

1 file changed

+70
-46
lines changed

comfy_extras/nodes_morphology.py

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,34 @@
11
import torch
22
import comfy.model_management
3+
from typing_extensions import override
4+
from comfy_api.latest import ComfyExtension, io
35

46
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
57
import kornia.color
68

79

8-
class Morphology:
10+
class Morphology(io.ComfyNode):
911
@classmethod
10-
def INPUT_TYPES(s):
11-
return {"required": {"image": ("IMAGE",),
12-
"operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],),
13-
"kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}),
14-
}}
12+
def define_schema(cls):
13+
return io.Schema(
14+
node_id="Morphology",
15+
display_name="ImageMorphology",
16+
category="image/postprocessing",
17+
inputs=[
18+
io.Image.Input("image"),
19+
io.Combo.Input(
20+
"operation",
21+
options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],
22+
),
23+
io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
24+
],
25+
outputs=[
26+
io.Image.Output(),
27+
],
28+
)
1529

16-
RETURN_TYPES = ("IMAGE",)
17-
FUNCTION = "process"
18-
19-
CATEGORY = "image/postprocessing"
20-
21-
def process(self, image, operation, kernel_size):
30+
@classmethod
31+
def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
2232
device = comfy.model_management.get_torch_device()
2333
kernel = torch.ones(kernel_size, kernel_size, device=device)
2434
image_k = image.to(device).movedim(-1, 1)
@@ -39,49 +49,63 @@ def process(self, image, operation, kernel_size):
3949
else:
4050
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
4151
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
42-
return (img_out,)
52+
return io.NodeOutput(img_out)
4353

4454

45-
class ImageRGBToYUV:
55+
class ImageRGBToYUV(io.ComfyNode):
4656
@classmethod
47-
def INPUT_TYPES(s):
48-
return {"required": { "image": ("IMAGE",),
49-
}}
57+
def define_schema(cls):
58+
return io.Schema(
59+
node_id="ImageRGBToYUV",
60+
category="image/batch",
61+
inputs=[
62+
io.Image.Input("image"),
63+
],
64+
outputs=[
65+
io.Image.Output(display_name="Y"),
66+
io.Image.Output(display_name="U"),
67+
io.Image.Output(display_name="V"),
68+
],
69+
)
5070

51-
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
52-
RETURN_NAMES = ("Y", "U", "V")
53-
FUNCTION = "execute"
54-
55-
CATEGORY = "image/batch"
56-
57-
def execute(self, image):
71+
@classmethod
72+
def execute(cls, image) -> io.NodeOutput:
5873
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
59-
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
74+
return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
6075

61-
class ImageYUVToRGB:
76+
class ImageYUVToRGB(io.ComfyNode):
6277
@classmethod
63-
def INPUT_TYPES(s):
64-
return {"required": {"Y": ("IMAGE",),
65-
"U": ("IMAGE",),
66-
"V": ("IMAGE",),
67-
}}
78+
def define_schema(cls):
79+
return io.Schema(
80+
node_id="ImageYUVToRGB",
81+
category="image/batch",
82+
inputs=[
83+
io.Image.Input("Y"),
84+
io.Image.Input("U"),
85+
io.Image.Input("V"),
86+
],
87+
outputs=[
88+
io.Image.Output(),
89+
],
90+
)
6891

69-
RETURN_TYPES = ("IMAGE",)
70-
FUNCTION = "execute"
71-
72-
CATEGORY = "image/batch"
73-
74-
def execute(self, Y, U, V):
92+
@classmethod
93+
def execute(cls, Y, U, V) -> io.NodeOutput:
7594
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
7695
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
77-
return (out,)
96+
return io.NodeOutput(out)
97+
98+
99+
class MorphologyExtension(ComfyExtension):
100+
@override
101+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
102+
return [
103+
Morphology,
104+
ImageRGBToYUV,
105+
ImageYUVToRGB,
106+
]
107+
78108

79-
NODE_CLASS_MAPPINGS = {
80-
"Morphology": Morphology,
81-
"ImageRGBToYUV": ImageRGBToYUV,
82-
"ImageYUVToRGB": ImageYUVToRGB,
83-
}
109+
async def comfy_entrypoint() -> MorphologyExtension:
110+
return MorphologyExtension()
84111

85-
NODE_DISPLAY_NAME_MAPPINGS = {
86-
"Morphology": "ImageMorphology",
87-
}

0 commit comments

Comments
 (0)