Skip to content

Commit 7ea173c

Browse files
authored
convert nodes_fresca.py to V3 schema (#9951)
1 parent 76eb1d7 commit 7ea173c

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

comfy_extras/nodes_fresca.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Code based on https://github.com/WikiChao/FreSca (MIT License)
22
import torch
33
import torch.fft as fft
4+
from typing_extensions import override
5+
from comfy_api.latest import ComfyExtension, io
46

57

68
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
@@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
5153
return x_filtered
5254

5355

54-
class FreSca:
56+
class FreSca(io.ComfyNode):
5557
@classmethod
56-
def INPUT_TYPES(s):
57-
return {
58-
"required": {
59-
"model": ("MODEL",),
60-
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
61-
"tooltip": "Scaling factor for low-frequency components"}),
62-
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
63-
"tooltip": "Scaling factor for high-frequency components"}),
64-
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
65-
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
66-
}
67-
}
68-
RETURN_TYPES = ("MODEL",)
69-
FUNCTION = "patch"
70-
CATEGORY = "_for_testing"
71-
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
72-
def patch(self, model, scale_low, scale_high, freq_cutoff):
58+
def define_schema(cls):
59+
return io.Schema(
60+
node_id="FreSca",
61+
display_name="FreSca",
62+
category="_for_testing",
63+
description="Applies frequency-dependent scaling to the guidance",
64+
inputs=[
65+
io.Model.Input("model"),
66+
io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
67+
tooltip="Scaling factor for low-frequency components"),
68+
io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
69+
tooltip="Scaling factor for high-frequency components"),
70+
io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
71+
tooltip="Number of frequency indices around center to consider as low-frequency"),
72+
],
73+
outputs=[
74+
io.Model.Output(),
75+
],
76+
is_experimental=True,
77+
)
78+
79+
@classmethod
80+
def execute(cls, model, scale_low, scale_high, freq_cutoff):
7381
def custom_cfg_function(args):
7482
conds_out = args["conds_out"]
7583
if len(conds_out) <= 1 or None in args["conds"][:2]:
@@ -91,13 +99,16 @@ def custom_cfg_function(args):
9199
m = model.clone()
92100
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
93101

94-
return (m,)
102+
return io.NodeOutput(m)
103+
95104

105+
class FreScaExtension(ComfyExtension):
106+
@override
107+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
108+
return [
109+
FreSca,
110+
]
96111

97-
NODE_CLASS_MAPPINGS = {
98-
"FreSca": FreSca,
99-
}
100112

101-
NODE_DISPLAY_NAME_MAPPINGS = {
102-
"FreSca": "FreSca",
103-
}
113+
async def comfy_entrypoint() -> FreScaExtension:
114+
return FreScaExtension()

0 commit comments

Comments
 (0)