File tree Expand file tree Collapse file tree 2 files changed +61
-0
lines changed Expand file tree Collapse file tree 2 files changed +61
-0
lines changed Original file line number Diff line number Diff line change
1
+ class EpsilonScaling :
2
+ """
3
+ Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
4
+ (https://arxiv.org/abs/2308.15321v6).
5
+
6
+ This method mitigates exposure bias by scaling the predicted noise during sampling,
7
+ which can significantly improve sample quality. This implementation uses the "uniform schedule"
8
+ recommended by the paper for its practicality and effectiveness.
9
+ """
10
+ @classmethod
11
+ def INPUT_TYPES (s ):
12
+ return {
13
+ "required" : {
14
+ "model" : ("MODEL" ,),
15
+ "scaling_factor" : ("FLOAT" , {
16
+ "default" : 1.005 ,
17
+ "min" : 0.5 ,
18
+ "max" : 1.5 ,
19
+ "step" : 0.001 ,
20
+ "display" : "number"
21
+ }),
22
+ }
23
+ }
24
+
25
+ RETURN_TYPES = ("MODEL" ,)
26
+ FUNCTION = "patch"
27
+
28
+ CATEGORY = "model_patches/unet"
29
+
30
+ def patch (self , model , scaling_factor ):
31
+ # Prevent division by zero, though the UI's min value should prevent this.
32
+ if scaling_factor == 0 :
33
+ scaling_factor = 1e-9
34
+
35
+ def epsilon_scaling_function (args ):
36
+ """
37
+ This function is applied after the CFG guidance has been calculated.
38
+ It recalculates the denoised latent by scaling the predicted noise.
39
+ """
40
+ denoised = args ["denoised" ]
41
+ x = args ["input" ]
42
+
43
+ noise_pred = x - denoised
44
+
45
+ scaled_noise_pred = noise_pred / scaling_factor
46
+
47
+ new_denoised = x - scaled_noise_pred
48
+
49
+ return new_denoised
50
+
51
+ # Clone the model patcher to avoid modifying the original model in place
52
+ model_clone = model .clone ()
53
+
54
+ model_clone .set_model_sampler_post_cfg_function (epsilon_scaling_function )
55
+
56
+ return (model_clone ,)
57
+
58
+ NODE_CLASS_MAPPINGS = {
59
+ "Epsilon Scaling" : EpsilonScaling
60
+ }
Original file line number Diff line number Diff line change @@ -2297,6 +2297,7 @@ async def init_builtin_extra_nodes():
2297
2297
"nodes_gits.py" ,
2298
2298
"nodes_controlnet.py" ,
2299
2299
"nodes_hunyuan.py" ,
2300
+ "nodes_eps.py" ,
2300
2301
"nodes_flux.py" ,
2301
2302
"nodes_lora_extract.py" ,
2302
2303
"nodes_torch_compile.py" ,
You can’t perform that action at this time.
0 commit comments