Skip to content

Commit d7aa414

Browse files
authored
convert nodes_eps.py to V3 schema (#10172)
1 parent 3e68bc3 commit d7aa414

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

comfy_extras/nodes_eps.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
class EpsilonScaling:
1+
from typing_extensions import override
2+
3+
from comfy_api.latest import ComfyExtension, io
4+
5+
6+
class EpsilonScaling(io.ComfyNode):
27
"""
38
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
49
(https://arxiv.org/abs/2308.15321v6).
@@ -8,26 +13,28 @@ class EpsilonScaling:
813
recommended by the paper for its practicality and effectiveness.
914
"""
1015
@classmethod
11-
def INPUT_TYPES(s):
12-
return {
13-
"required": {
14-
"model": ("MODEL",),
15-
"scaling_factor": ("FLOAT", {
16-
"default": 1.005,
17-
"min": 0.5,
18-
"max": 1.5,
19-
"step": 0.001,
20-
"display": "number"
21-
}),
22-
}
23-
}
24-
25-
RETURN_TYPES = ("MODEL",)
26-
FUNCTION = "patch"
27-
28-
CATEGORY = "model_patches/unet"
29-
30-
def patch(self, model, scaling_factor):
16+
def define_schema(cls):
17+
return io.Schema(
18+
node_id="Epsilon Scaling",
19+
category="model_patches/unet",
20+
inputs=[
21+
io.Model.Input("model"),
22+
io.Float.Input(
23+
"scaling_factor",
24+
default=1.005,
25+
min=0.5,
26+
max=1.5,
27+
step=0.001,
28+
display_mode=io.NumberDisplay.number,
29+
),
30+
],
31+
outputs=[
32+
io.Model.Output(),
33+
],
34+
)
35+
36+
@classmethod
37+
def execute(cls, model, scaling_factor) -> io.NodeOutput:
3138
# Prevent division by zero, though the UI's min value should prevent this.
3239
if scaling_factor == 0:
3340
scaling_factor = 1e-9
@@ -53,8 +60,15 @@ def epsilon_scaling_function(args):
5360

5461
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
5562

56-
return (model_clone,)
63+
return io.NodeOutput(model_clone)
64+
65+
66+
class EpsilonScalingExtension(ComfyExtension):
67+
@override
68+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
69+
return [
70+
EpsilonScaling,
71+
]
5772

58-
NODE_CLASS_MAPPINGS = {
59-
"Epsilon Scaling": EpsilonScaling
60-
}
73+
async def comfy_entrypoint() -> EpsilonScalingExtension:
74+
return EpsilonScalingExtension()

0 commit comments

Comments
 (0)