Skip to content

Commit d20576e

Browse files
authored
convert nodes_sag.py to V3 schema (#9940)
1 parent a061b06 commit d20576e

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

comfy_extras/nodes_sag.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
from torch import einsum
33
import torch.nn.functional as F
44
import math
5+
from typing_extensions import override
56

67
from einops import rearrange, repeat
78
from comfy.ldm.modules.attention import optimized_attention
89
import comfy.samplers
10+
from comfy_api.latest import ComfyExtension, io
11+
912

1013
# from comfy/ldm/modules/attention.py
1114
# but modified to return attention scores as well as output
@@ -104,19 +107,26 @@ def gaussian_blur_2d(img, kernel_size, sigma):
104107
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
105108
return img
106109

107-
class SelfAttentionGuidance:
110+
class SelfAttentionGuidance(io.ComfyNode):
108111
@classmethod
109-
def INPUT_TYPES(s):
110-
return {"required": { "model": ("MODEL",),
111-
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
112-
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
113-
}}
114-
RETURN_TYPES = ("MODEL",)
115-
FUNCTION = "patch"
116-
117-
CATEGORY = "_for_testing"
112+
def define_schema(cls):
113+
return io.Schema(
114+
node_id="SelfAttentionGuidance",
115+
display_name="Self-Attention Guidance",
116+
category="_for_testing",
117+
inputs=[
118+
io.Model.Input("model"),
119+
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
120+
io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
121+
],
122+
outputs=[
123+
io.Model.Output(),
124+
],
125+
is_experimental=True,
126+
)
118127

119-
def patch(self, model, scale, blur_sigma):
128+
@classmethod
129+
def execute(cls, model, scale, blur_sigma):
120130
m = model.clone()
121131

122132
attn_scores = None
@@ -170,12 +180,16 @@ def post_cfg_function(args):
170180
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
171181
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
172182

173-
return (m, )
183+
return io.NodeOutput(m)
184+
185+
186+
class SagExtension(ComfyExtension):
187+
@override
188+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
189+
return [
190+
SelfAttentionGuidance,
191+
]
174192

175-
NODE_CLASS_MAPPINGS = {
176-
"SelfAttentionGuidance": SelfAttentionGuidance,
177-
}
178193

179-
NODE_DISPLAY_NAME_MAPPINGS = {
180-
"SelfAttentionGuidance": "Self-Attention Guidance",
181-
}
194+
async def comfy_entrypoint() -> SagExtension:
195+
return SagExtension()

0 commit comments

Comments (0)