Skip to content

Commit a061b06

Browse files
authored
convert nodes_tcfg.py to V3 schema (#9942)
1 parent 8071890 commit a061b06

File tree

1 file changed

+28
-23
lines changed

1 file changed

+28
-23
lines changed

comfy_extras/nodes_tcfg.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
22

3+
from typing_extensions import override
34
import torch
45

5-
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
6+
from comfy_api.latest import ComfyExtension, io
67

78

89
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
@@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso
2627
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
2728

2829

29-
class TCFG(ComfyNodeABC):
30+
class TCFG(io.ComfyNode):
3031
@classmethod
31-
def INPUT_TYPES(cls) -> InputTypeDict:
32-
return {
33-
"required": {
34-
"model": (IO.MODEL, {}),
35-
}
36-
}
32+
def define_schema(cls):
33+
return io.Schema(
34+
node_id="TCFG",
35+
display_name="Tangential Damping CFG",
36+
category="advanced/guidance",
37+
description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
38+
inputs=[
39+
io.Model.Input("model"),
40+
],
41+
outputs=[
42+
io.Model.Output(display_name="patched_model"),
43+
],
44+
)
3745

38-
RETURN_TYPES = (IO.MODEL,)
39-
RETURN_NAMES = ("patched_model",)
40-
FUNCTION = "patch"
41-
42-
CATEGORY = "advanced/guidance"
43-
DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
44-
45-
def patch(self, model):
46+
@classmethod
47+
def execute(cls, model):
4648
m = model.clone()
4749

4850
def tangential_damping_cfg(args):
@@ -59,13 +61,16 @@ def tangential_damping_cfg(args):
5961
return [cond_pred, uncond_pred_td] + conds_out[2:]
6062

6163
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
62-
return (m,)
64+
return io.NodeOutput(m)
65+
6366

67+
class TcfgExtension(ComfyExtension):
68+
@override
69+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
70+
return [
71+
TCFG,
72+
]
6473

65-
NODE_CLASS_MAPPINGS = {
66-
"TCFG": TCFG,
67-
}
6874

69-
NODE_DISPLAY_NAME_MAPPINGS = {
70-
"TCFG": "Tangential Damping CFG",
71-
}
75+
async def comfy_entrypoint() -> TcfgExtension:
76+
return TcfgExtension()

0 commit comments

Comments
 (0)