Skip to content

Commit f48d05a

Browse files
authored
convert AlignYourStepsScheduler node to V3 schema (#9226)
1 parent 4368d8f commit f48d05a

File tree

1 file changed

+33
-17
lines changed

1 file changed

+33
-17
lines changed
Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
22
import numpy as np
33
import torch
4+
from typing_extensions import override
5+
6+
from comfy_api.latest import ComfyExtension, io
7+
48

59
def loglinear_interp(t_steps, num_steps):
610
"""
@@ -19,25 +23,30 @@ def loglinear_interp(t_steps, num_steps):
1923
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
2024
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
2125

22-
class AlignYourStepsScheduler:
26+
class AlignYourStepsScheduler(io.ComfyNode):
2327
@classmethod
24-
def INPUT_TYPES(s):
25-
return {"required":
26-
{"model_type": (["SD1", "SDXL", "SVD"], ),
27-
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
28-
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
29-
}
30-
}
31-
RETURN_TYPES = ("SIGMAS",)
32-
CATEGORY = "sampling/custom_sampling/schedulers"
33-
34-
FUNCTION = "get_sigmas"
28+
def define_schema(cls) -> io.Schema:
29+
return io.Schema(
30+
node_id="AlignYourStepsScheduler",
31+
category="sampling/custom_sampling/schedulers",
32+
inputs=[
33+
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
34+
io.Int.Input("steps", default=10, min=1, max=10000),
35+
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
36+
],
37+
outputs=[io.Sigmas.Output()],
38+
)
3539

3640
def get_sigmas(self, model_type, steps, denoise):
41+
# Deprecated: use the V3 schema's `execute` method instead of this.
42+
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
43+
44+
@classmethod
45+
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
3746
total_steps = steps
3847
if denoise < 1.0:
3948
if denoise <= 0.0:
40-
return (torch.FloatTensor([]),)
49+
return io.NodeOutput(torch.FloatTensor([]))
4150
total_steps = round(steps * denoise)
4251

4352
sigmas = NOISE_LEVELS[model_type][:]
@@ -46,8 +55,15 @@ def get_sigmas(self, model_type, steps, denoise):
4655

4756
sigmas = sigmas[-(total_steps + 1):]
4857
sigmas[-1] = 0
49-
return (torch.FloatTensor(sigmas), )
58+
return io.NodeOutput(torch.FloatTensor(sigmas))
59+
60+
61+
class AlignYourStepsExtension(ComfyExtension):
62+
@override
63+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
64+
return [
65+
AlignYourStepsScheduler,
66+
]
5067

51-
NODE_CLASS_MAPPINGS = {
52-
"AlignYourStepsScheduler": AlignYourStepsScheduler,
53-
}
68+
async def comfy_entrypoint() -> AlignYourStepsExtension:
69+
return AlignYourStepsExtension()

0 commit comments

Comments
 (0)