1
1
# Code based on https://github.com/WikiChao/FreSca (MIT License)
2
2
import torch
3
3
import torch .fft as fft
4
+ from typing_extensions import override
5
+ from comfy_api .latest import ComfyExtension , io
4
6
5
7
6
8
def Fourier_filter (x , scale_low = 1.0 , scale_high = 1.5 , freq_cutoff = 20 ):
@@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
51
53
return x_filtered
52
54
53
55
54
- class FreSca :
56
+ class FreSca ( io . ComfyNode ) :
55
57
@classmethod
56
- def INPUT_TYPES (s ):
57
- return {
58
- "required" : {
59
- "model" : ("MODEL" ,),
60
- "scale_low" : ("FLOAT" , {"default" : 1.0 , "min" : 0 , "max" : 10 , "step" : 0.01 ,
61
- "tooltip" : "Scaling factor for low-frequency components" }),
62
- "scale_high" : ("FLOAT" , {"default" : 1.25 , "min" : 0 , "max" : 10 , "step" : 0.01 ,
63
- "tooltip" : "Scaling factor for high-frequency components" }),
64
- "freq_cutoff" : ("INT" , {"default" : 20 , "min" : 1 , "max" : 10000 , "step" : 1 ,
65
- "tooltip" : "Number of frequency indices around center to consider as low-frequency" }),
66
- }
67
- }
68
- RETURN_TYPES = ("MODEL" ,)
69
- FUNCTION = "patch"
70
- CATEGORY = "_for_testing"
71
- DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
72
- def patch (self , model , scale_low , scale_high , freq_cutoff ):
58
+ def define_schema (cls ):
59
+ return io .Schema (
60
+ node_id = "FreSca" ,
61
+ display_name = "FreSca" ,
62
+ category = "_for_testing" ,
63
+ description = "Applies frequency-dependent scaling to the guidance" ,
64
+ inputs = [
65
+ io .Model .Input ("model" ),
66
+ io .Float .Input ("scale_low" , default = 1.0 , min = 0 , max = 10 , step = 0.01 ,
67
+ tooltip = "Scaling factor for low-frequency components" ),
68
+ io .Float .Input ("scale_high" , default = 1.25 , min = 0 , max = 10 , step = 0.01 ,
69
+ tooltip = "Scaling factor for high-frequency components" ),
70
+ io .Int .Input ("freq_cutoff" , default = 20 , min = 1 , max = 10000 , step = 1 ,
71
+ tooltip = "Number of frequency indices around center to consider as low-frequency" ),
72
+ ],
73
+ outputs = [
74
+ io .Model .Output (),
75
+ ],
76
+ is_experimental = True ,
77
+ )
78
+
79
+ @classmethod
80
+ def execute (cls , model , scale_low , scale_high , freq_cutoff ):
73
81
def custom_cfg_function (args ):
74
82
conds_out = args ["conds_out" ]
75
83
if len (conds_out ) <= 1 or None in args ["conds" ][:2 ]:
@@ -91,13 +99,16 @@ def custom_cfg_function(args):
91
99
m = model .clone ()
92
100
m .set_model_sampler_pre_cfg_function (custom_cfg_function )
93
101
94
- return (m ,)
102
+ return io .NodeOutput (m )
103
+
95
104
105
+ class FreScaExtension (ComfyExtension ):
106
+ @override
107
+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
108
+ return [
109
+ FreSca ,
110
+ ]
96
111
97
- NODE_CLASS_MAPPINGS = {
98
- "FreSca" : FreSca ,
99
- }
100
112
101
- NODE_DISPLAY_NAME_MAPPINGS = {
102
- "FreSca" : "FreSca" ,
103
- }
113
+ async def comfy_entrypoint () -> FreScaExtension :
114
+ return FreScaExtension ()
0 commit comments