Skip to content

Commit 117bf3f

Browse files
authored
convert nodes_freelunch.py to the V3 schema (#10904)
1 parent ae676ed commit 117bf3f

File tree

2 files changed

+59
-39
lines changed

2 files changed

+59
-39
lines changed

comfy_extras/nodes_freelunch.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import torch
44
import logging
5+
from typing_extensions import override
6+
from comfy_api.latest import ComfyExtension, IO
57

68
def Fourier_filter(x, threshold, scale):
79
# FFT
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
2224
return x_filtered.to(x.dtype)
2325

2426

25-
class FreeU:
27+
class FreeU(IO.ComfyNode):
2628
@classmethod
27-
def INPUT_TYPES(s):
28-
return {"required": { "model": ("MODEL",),
29-
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
30-
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
31-
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
32-
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
33-
}}
34-
RETURN_TYPES = ("MODEL",)
35-
FUNCTION = "patch"
36-
37-
CATEGORY = "model_patches/unet"
38-
39-
def patch(self, model, b1, b2, s1, s2):
29+
def define_schema(cls):
30+
return IO.Schema(
31+
node_id="FreeU",
32+
category="model_patches/unet",
33+
inputs=[
34+
IO.Model.Input("model"),
35+
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
36+
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
37+
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
38+
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
39+
],
40+
outputs=[
41+
IO.Model.Output(),
42+
],
43+
)
44+
45+
@classmethod
46+
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
4047
model_channels = model.model.model_config.unet_config["model_channels"]
4148
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
4249
on_cpu_devices = {}
@@ -59,23 +66,31 @@ def output_block_patch(h, hsp, transformer_options):
5966

6067
m = model.clone()
6168
m.set_model_output_block_patch(output_block_patch)
62-
return (m, )
69+
return IO.NodeOutput(m)
70+
71+
patch = execute # TODO: remove
72+
73+
74+
class FreeU_V2(IO.ComfyNode):
75+
@classmethod
76+
def define_schema(cls):
77+
return IO.Schema(
78+
node_id="FreeU_V2",
79+
category="model_patches/unet",
80+
inputs=[
81+
IO.Model.Input("model"),
82+
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
83+
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
84+
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
85+
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
86+
],
87+
outputs=[
88+
IO.Model.Output(),
89+
],
90+
)
6391

64-
class FreeU_V2:
6592
@classmethod
66-
def INPUT_TYPES(s):
67-
return {"required": { "model": ("MODEL",),
68-
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
69-
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
70-
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
71-
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
72-
}}
73-
RETURN_TYPES = ("MODEL",)
74-
FUNCTION = "patch"
75-
76-
CATEGORY = "model_patches/unet"
77-
78-
def patch(self, model, b1, b2, s1, s2):
93+
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
7994
model_channels = model.model.model_config.unet_config["model_channels"]
8095
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
8196
on_cpu_devices = {}
@@ -105,9 +120,19 @@ def output_block_patch(h, hsp, transformer_options):
105120

106121
m = model.clone()
107122
m.set_model_output_block_patch(output_block_patch)
108-
return (m, )
123+
return IO.NodeOutput(m)
124+
125+
patch = execute # TODO: remove
126+
127+
128+
class FreelunchExtension(ComfyExtension):
129+
@override
130+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
131+
return [
132+
FreeU,
133+
FreeU_V2,
134+
]
135+
109136

110-
NODE_CLASS_MAPPINGS = {
111-
"FreeU": FreeU,
112-
"FreeU_V2": FreeU_V2,
113-
}
137+
async def comfy_entrypoint() -> FreelunchExtension:
138+
return FreelunchExtension()

comfy_extras/nodes_model_downscale.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,6 @@ def output_block_patch(h, hsp, transformer_options):
5353
return io.NodeOutput(m)
5454

5555

56-
NODE_DISPLAY_NAME_MAPPINGS = {
57-
# Sampling
58-
"PatchModelAddDownscale": "",
59-
}
60-
6156
class ModelDownscaleExtension(ComfyExtension):
6257
@override
6358
async def get_node_list(self) -> list[type[io.ComfyNode]]:

0 commit comments

Comments
 (0)