1
- class EpsilonScaling :
1
+ from typing_extensions import override
2
+
3
+ from comfy_api .latest import ComfyExtension , io
4
+
5
+
6
+ class EpsilonScaling (io .ComfyNode ):
2
7
"""
3
8
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
4
9
(https://arxiv.org/abs/2308.15321v6).
@@ -8,26 +13,28 @@ class EpsilonScaling:
8
13
recommended by the paper for its practicality and effectiveness.
9
14
"""
10
15
@classmethod
11
- def INPUT_TYPES (s ):
12
- return {
13
- "required" : {
14
- "model" : ("MODEL" ,),
15
- "scaling_factor" : ("FLOAT" , {
16
- "default" : 1.005 ,
17
- "min" : 0.5 ,
18
- "max" : 1.5 ,
19
- "step" : 0.001 ,
20
- "display" : "number"
21
- }),
22
- }
23
- }
24
-
25
- RETURN_TYPES = ("MODEL" ,)
26
- FUNCTION = "patch"
27
-
28
- CATEGORY = "model_patches/unet"
29
-
30
- def patch (self , model , scaling_factor ):
16
+ def define_schema (cls ):
17
+ return io .Schema (
18
+ node_id = "Epsilon Scaling" ,
19
+ category = "model_patches/unet" ,
20
+ inputs = [
21
+ io .Model .Input ("model" ),
22
+ io .Float .Input (
23
+ "scaling_factor" ,
24
+ default = 1.005 ,
25
+ min = 0.5 ,
26
+ max = 1.5 ,
27
+ step = 0.001 ,
28
+ display_mode = io .NumberDisplay .number ,
29
+ ),
30
+ ],
31
+ outputs = [
32
+ io .Model .Output (),
33
+ ],
34
+ )
35
+
36
+ @classmethod
37
+ def execute (cls , model , scaling_factor ) -> io .NodeOutput :
31
38
# Prevent division by zero, though the UI's min value should prevent this.
32
39
if scaling_factor == 0 :
33
40
scaling_factor = 1e-9
@@ -53,8 +60,15 @@ def epsilon_scaling_function(args):
53
60
54
61
model_clone .set_model_sampler_post_cfg_function (epsilon_scaling_function )
55
62
56
- return (model_clone ,)
63
+ return io .NodeOutput (model_clone )
64
+
65
+
66
+ class EpsilonScalingExtension (ComfyExtension ):
67
+ @override
68
+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
69
+ return [
70
+ EpsilonScaling ,
71
+ ]
57
72
58
- NODE_CLASS_MAPPINGS = {
59
- "Epsilon Scaling" : EpsilonScaling
60
- }
73
+ async def comfy_entrypoint () -> EpsilonScalingExtension :
74
+ return EpsilonScalingExtension ()
0 commit comments