@@ -2,10 +2,13 @@
 from torch import einsum
 import torch.nn.functional as F
 import math
+from typing_extensions import override
 
 from einops import rearrange, repeat
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.samplers
+from comfy_api.latest import ComfyExtension, io
+
 
 # from comfy/ldm/modules/attention.py
 # but modified to return attention scores as well as output
@@ -104,19 +107,26 @@ def gaussian_blur_2d(img, kernel_size, sigma):
     img = F.conv2d(img, kernel2d, groups=img.shape[-3])
     return img
 
-class SelfAttentionGuidance:
+class SelfAttentionGuidance(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
-                              "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
-
-    CATEGORY = "_for_testing"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SelfAttentionGuidance",
+            display_name="Self-Attention Guidance",
+            category="_for_testing",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
+                io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+            is_experimental=True,
+        )
 
-    def patch(self, model, scale, blur_sigma):
+    @classmethod
+    def execute(cls, model, scale, blur_sigma):
         m = model.clone()
 
         attn_scores = None
@@ -170,12 +180,16 @@ def post_cfg_function(args):
        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
 
-        return (m, )
+        return io.NodeOutput(m)
+
+
+class SagExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            SelfAttentionGuidance,
+        ]
 
-NODE_CLASS_MAPPINGS = {
-    "SelfAttentionGuidance": SelfAttentionGuidance,
-}
 
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "SelfAttentionGuidance": "Self-Attention Guidance",
-}
+async def comfy_entrypoint() -> SagExtension:
+    return SagExtension()
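For orientation, the fully migrated shape of such a node is sketched below. This is a minimal sketch, not part of the change itself: it uses only the comfy_api.latest names that appear in this diff (io.ComfyNode, io.Schema, io.NodeOutput, ComfyExtension, comfy_entrypoint), while ExampleNode and its "value" input are hypothetical placeholders for illustration.

# Minimal sketch of the schema-based node/extension pattern shown above.
# "ExampleNode" and its inputs are hypothetical; the io / ComfyExtension
# APIs are the ones the migrated SelfAttentionGuidance node uses.
from typing_extensions import override

from comfy_api.latest import ComfyExtension, io


class ExampleNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        # io.Schema consolidates the old INPUT_TYPES / RETURN_TYPES /
        # FUNCTION / CATEGORY class attributes into one declarative object.
        return io.Schema(
            node_id="ExampleNode",
            display_name="Example Node",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("value", default=1.0, min=0.0, max=10.0, step=0.1),
            ],
            outputs=[
                io.Model.Output(),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model, value):
        # Old-style nodes returned a tuple ("return (m, )"); schema-based
        # nodes wrap their results in io.NodeOutput.
        return io.NodeOutput(model)


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Replaces the NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dicts;
        # display names now live in each node's Schema.
        return [ExampleNode]


async def comfy_entrypoint() -> ExampleExtension:
    # Module-level entrypoint that registers the extension, taking over the
    # role of the old mapping dicts at the bottom of the file.
    return ExampleExtension()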