Skip to content

Commit bb32d4e

Browse files
authored
feat: Add Epsilon Scaling node for exposure bias correction (#10132)
1 parent a6f83a4 commit bb32d4e

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

comfy_extras/nodes_eps.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
class EpsilonScaling:
2+
"""
3+
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
4+
(https://arxiv.org/abs/2308.15321v6).
5+
6+
This method mitigates exposure bias by scaling the predicted noise during sampling,
7+
which can significantly improve sample quality. This implementation uses the "uniform schedule"
8+
recommended by the paper for its practicality and effectiveness.
9+
"""
10+
@classmethod
11+
def INPUT_TYPES(s):
12+
return {
13+
"required": {
14+
"model": ("MODEL",),
15+
"scaling_factor": ("FLOAT", {
16+
"default": 1.005,
17+
"min": 0.5,
18+
"max": 1.5,
19+
"step": 0.001,
20+
"display": "number"
21+
}),
22+
}
23+
}
24+
25+
RETURN_TYPES = ("MODEL",)
26+
FUNCTION = "patch"
27+
28+
CATEGORY = "model_patches/unet"
29+
30+
def patch(self, model, scaling_factor):
31+
# Prevent division by zero, though the UI's min value should prevent this.
32+
if scaling_factor == 0:
33+
scaling_factor = 1e-9
34+
35+
def epsilon_scaling_function(args):
36+
"""
37+
This function is applied after the CFG guidance has been calculated.
38+
It recalculates the denoised latent by scaling the predicted noise.
39+
"""
40+
denoised = args["denoised"]
41+
x = args["input"]
42+
43+
noise_pred = x - denoised
44+
45+
scaled_noise_pred = noise_pred / scaling_factor
46+
47+
new_denoised = x - scaled_noise_pred
48+
49+
return new_denoised
50+
51+
# Clone the model patcher to avoid modifying the original model in place
52+
model_clone = model.clone()
53+
54+
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
55+
56+
return (model_clone,)
57+
58+
NODE_CLASS_MAPPINGS = {
59+
"Epsilon Scaling": EpsilonScaling
60+
}

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,6 +2297,7 @@ async def init_builtin_extra_nodes():
22972297
"nodes_gits.py",
22982298
"nodes_controlnet.py",
22992299
"nodes_hunyuan.py",
2300+
"nodes_eps.py",
23002301
"nodes_flux.py",
23012302
"nodes_lora_extract.py",
23022303
"nodes_torch_compile.py",

0 commit comments

Comments
 (0)