Skip to content

Commit 265b4f0

Browse files
committed
init
1 parent 6ca3d5c commit 265b4f0

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

comfy_extras/nodes_sage3.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import Callable
2+
3+
import torch
4+
from typing_extensions import override
5+
6+
from comfy.ldm.modules.attention import get_attention_function
7+
from comfy.model_patcher import ModelPatcher
8+
from comfy_api.latest import ComfyExtension, io
9+
from server import PromptServer
10+
11+
12+
class Sage3PatchModel(io.ComfyNode):
13+
@classmethod
14+
def define_schema(cls):
15+
return io.Schema(
16+
node_id="Sage3PatchModel",
17+
display_name="Patch SageAttention 3",
18+
description="Apply `attention3_sage` to the middle blocks and steps, while using optimized_attention for the first/last blocks and steps",
19+
category="_for_testing",
20+
inputs=[
21+
io.Model.Input("model"),
22+
],
23+
outputs=[io.Model.Output()],
24+
is_experimental=True,
25+
)
26+
27+
@classmethod
28+
def execute(cls, model: ModelPatcher) -> io.NodeOutput:
29+
sage3: Callable | None = get_attention_function("sage3", default=None)
30+
31+
if sage3 is None:
32+
PromptServer.instance.send_progress_text(
33+
"`sageattn3` is not installed / available...",
34+
cls.hidden.unique_id,
35+
)
36+
return io.NodeOutput(model)
37+
38+
def attention_override(func: Callable, *args, **kwargs):
39+
transformer_options: dict = kwargs.get("transformer_options", {})
40+
41+
block_index: int = transformer_options.get("block_index", 0)
42+
total_blocks: int = transformer_options.get("total_blocks", 1)
43+
44+
if block_index == 0 or block_index >= (total_blocks - 1):
45+
return func(*args, **kwargs)
46+
47+
sample_sigmas: torch.Tensor = transformer_options["sample_sigmas"]
48+
sigmas: torch.Tensor = transformer_options["sigmas"]
49+
50+
total_steps: int = sample_sigmas.size(0)
51+
step: int = 0
52+
53+
for i in range(total_steps):
54+
if torch.allclose(sample_sigmas[i], sigmas):
55+
step = i
56+
break
57+
58+
if step == 0 or step >= (total_steps - 1):
59+
return func(*args, **kwargs)
60+
61+
return sage3(*args, **kwargs)
62+
63+
model = model.clone()
64+
model.model_options["transformer_options"][
65+
"optimized_attention_override"
66+
] = attention_override
67+
68+
return io.NodeOutput(model)
69+
70+
71+
class Sage3Extension(ComfyExtension):
72+
@override
73+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
74+
return [Sage3PatchModel]
75+
76+
77+
async def comfy_entrypoint():
78+
return Sage3Extension()

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2360,6 +2360,7 @@ async def init_builtin_extra_nodes():
23602360
"nodes_nop.py",
23612361
"nodes_kandinsky5.py",
23622362
"nodes_wanmove.py",
2363+
"nodes_sage3.py",
23632364
]
23642365

23652366
import_failed = []

0 commit comments

Comments
 (0)