1
1
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
2
2
3
+ from typing_extensions import override
3
4
import torch
4
5
5
- from comfy . comfy_types import IO , ComfyNodeABC , InputTypeDict
6
+ from comfy_api . latest import ComfyExtension , io
6
7
7
8
8
9
def score_tangential_damping (cond_score : torch .Tensor , uncond_score : torch .Tensor ) -> torch .Tensor :
@@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso
26
27
return uncond_score_td .reshape_as (uncond_score ).to (uncond_score .dtype )
27
28
28
29
29
- class TCFG (ComfyNodeABC ):
30
+ class TCFG (io . ComfyNode ):
30
31
@classmethod
31
- def INPUT_TYPES (cls ) -> InputTypeDict :
32
- return {
33
- "required" : {
34
- "model" : (IO .MODEL , {}),
35
- }
36
- }
32
+ def define_schema (cls ):
33
+ return io .Schema (
34
+ node_id = "TCFG" ,
35
+ display_name = "Tangential Damping CFG" ,
36
+ category = "advanced/guidance" ,
37
+ description = "TCFG – Tangential Damping CFG (2503.18137)\n \n Refine the uncond (negative) to align with the cond (positive) for improving quality." ,
38
+ inputs = [
39
+ io .Model .Input ("model" ),
40
+ ],
41
+ outputs = [
42
+ io .Model .Output (display_name = "patched_model" ),
43
+ ],
44
+ )
37
45
38
- RETURN_TYPES = (IO .MODEL ,)
39
- RETURN_NAMES = ("patched_model" ,)
40
- FUNCTION = "patch"
41
-
42
- CATEGORY = "advanced/guidance"
43
- DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n \n Refine the uncond (negative) to align with the cond (positive) for improving quality."
44
-
45
- def patch (self , model ):
46
+ @classmethod
47
+ def execute (cls , model ):
46
48
m = model .clone ()
47
49
48
50
def tangential_damping_cfg (args ):
@@ -59,13 +61,16 @@ def tangential_damping_cfg(args):
59
61
return [cond_pred , uncond_pred_td ] + conds_out [2 :]
60
62
61
63
m .set_model_sampler_pre_cfg_function (tangential_damping_cfg )
62
- return (m ,)
64
+ return io .NodeOutput (m )
65
+
63
66
67
+ class TcfgExtension (ComfyExtension ):
68
+ @override
69
+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
70
+ return [
71
+ TCFG ,
72
+ ]
64
73
65
- NODE_CLASS_MAPPINGS = {
66
- "TCFG" : TCFG ,
67
- }
68
74
69
- NODE_DISPLAY_NAME_MAPPINGS = {
70
- "TCFG" : "Tangential Damping CFG" ,
71
- }
75
+ async def comfy_entrypoint () -> TcfgExtension :
76
+ return TcfgExtension ()
0 commit comments