22
33import torch
44import logging
5+ from typing_extensions import override
6+ from comfy_api .latest import ComfyExtension , IO
57
68def Fourier_filter (x , threshold , scale ):
79 # FFT
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
2224 return x_filtered .to (x .dtype )
2325
2426
25- class FreeU :
27+ class FreeU ( IO . ComfyNode ) :
2628 @classmethod
27- def INPUT_TYPES (s ):
28- return {"required" : { "model" : ("MODEL" ,),
29- "b1" : ("FLOAT" , {"default" : 1.1 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
30- "b2" : ("FLOAT" , {"default" : 1.2 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
31- "s1" : ("FLOAT" , {"default" : 0.9 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
32- "s2" : ("FLOAT" , {"default" : 0.2 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
33- }}
34- RETURN_TYPES = ("MODEL" ,)
35- FUNCTION = "patch"
36-
37- CATEGORY = "model_patches/unet"
38-
39- def patch (self , model , b1 , b2 , s1 , s2 ):
29+ def define_schema (cls ):
30+ return IO .Schema (
31+ node_id = "FreeU" ,
32+ category = "model_patches/unet" ,
33+ inputs = [
34+ IO .Model .Input ("model" ),
35+ IO .Float .Input ("b1" , default = 1.1 , min = 0.0 , max = 10.0 , step = 0.01 ),
36+ IO .Float .Input ("b2" , default = 1.2 , min = 0.0 , max = 10.0 , step = 0.01 ),
37+ IO .Float .Input ("s1" , default = 0.9 , min = 0.0 , max = 10.0 , step = 0.01 ),
38+ IO .Float .Input ("s2" , default = 0.2 , min = 0.0 , max = 10.0 , step = 0.01 ),
39+ ],
40+ outputs = [
41+ IO .Model .Output (),
42+ ],
43+ )
44+
45+ @classmethod
46+ def execute (cls , model , b1 , b2 , s1 , s2 ) -> IO .NodeOutput :
4047 model_channels = model .model .model_config .unet_config ["model_channels" ]
4148 scale_dict = {model_channels * 4 : (b1 , s1 ), model_channels * 2 : (b2 , s2 )}
4249 on_cpu_devices = {}
@@ -59,23 +66,31 @@ def output_block_patch(h, hsp, transformer_options):
5966
6067 m = model .clone ()
6168 m .set_model_output_block_patch (output_block_patch )
62- return (m , )
69+ return IO .NodeOutput (m )
70+
71+ patch = execute # TODO: remove
72+
73+
74+ class FreeU_V2 (IO .ComfyNode ):
75+ @classmethod
76+ def define_schema (cls ):
77+ return IO .Schema (
78+ node_id = "FreeU_V2" ,
79+ category = "model_patches/unet" ,
80+ inputs = [
81+ IO .Model .Input ("model" ),
82+ IO .Float .Input ("b1" , default = 1.3 , min = 0.0 , max = 10.0 , step = 0.01 ),
83+ IO .Float .Input ("b2" , default = 1.4 , min = 0.0 , max = 10.0 , step = 0.01 ),
84+ IO .Float .Input ("s1" , default = 0.9 , min = 0.0 , max = 10.0 , step = 0.01 ),
85+ IO .Float .Input ("s2" , default = 0.2 , min = 0.0 , max = 10.0 , step = 0.01 ),
86+ ],
87+ outputs = [
88+ IO .Model .Output (),
89+ ],
90+ )
6391
64- class FreeU_V2 :
6592 @classmethod
66- def INPUT_TYPES (s ):
67- return {"required" : { "model" : ("MODEL" ,),
68- "b1" : ("FLOAT" , {"default" : 1.3 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
69- "b2" : ("FLOAT" , {"default" : 1.4 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
70- "s1" : ("FLOAT" , {"default" : 0.9 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
71- "s2" : ("FLOAT" , {"default" : 0.2 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
72- }}
73- RETURN_TYPES = ("MODEL" ,)
74- FUNCTION = "patch"
75-
76- CATEGORY = "model_patches/unet"
77-
78- def patch (self , model , b1 , b2 , s1 , s2 ):
93+ def execute (cls , model , b1 , b2 , s1 , s2 ) -> IO .NodeOutput :
7994 model_channels = model .model .model_config .unet_config ["model_channels" ]
8095 scale_dict = {model_channels * 4 : (b1 , s1 ), model_channels * 2 : (b2 , s2 )}
8196 on_cpu_devices = {}
@@ -105,9 +120,19 @@ def output_block_patch(h, hsp, transformer_options):
105120
106121 m = model .clone ()
107122 m .set_model_output_block_patch (output_block_patch )
108- return (m , )
123+ return IO .NodeOutput (m )
124+
125+ patch = execute # TODO: remove
126+
127+
128+ class FreelunchExtension (ComfyExtension ):
129+ @override
130+ async def get_node_list (self ) -> list [type [IO .ComfyNode ]]:
131+ return [
132+ FreeU ,
133+ FreeU_V2 ,
134+ ]
135+
109136
110- NODE_CLASS_MAPPINGS = {
111- "FreeU" : FreeU ,
112- "FreeU_V2" : FreeU_V2 ,
113- }
137+ async def comfy_entrypoint () -> FreelunchExtension :
138+ return FreelunchExtension ()
0 commit comments