Skip to content

Commit 2dadb34

Browse files
authored
convert nodes_hypertile.py to V3 schema (#10061)
1 parent 1cf86f5 commit 2dadb34

File tree

1 file changed

+39
-22
lines changed

1 file changed

+39
-22
lines changed

comfy_extras/nodes_hypertile.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#Taken from: https://github.com/tfernd/HyperTile/
22

33
import math
4+
from typing_extensions import override
45
from einops import rearrange
56
# Use torch rng for consistency across generations
67
from torch import randint
8+
from comfy_api.latest import ComfyExtension, io
79

810
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
911
min_value = min(min_value, value)
@@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
2022

2123
return ns[idx]
2224

23-
class HyperTile:
25+
class HyperTile(io.ComfyNode):
2426
@classmethod
25-
def INPUT_TYPES(s):
26-
return {"required": { "model": ("MODEL",),
27-
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
28-
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
29-
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
30-
"scale_depth": ("BOOLEAN", {"default": False}),
31-
}}
32-
RETURN_TYPES = ("MODEL",)
33-
FUNCTION = "patch"
34-
35-
CATEGORY = "model_patches/unet"
36-
37-
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
27+
def define_schema(cls):
28+
return io.Schema(
29+
node_id="HyperTile",
30+
category="model_patches/unet",
31+
inputs=[
32+
io.Model.Input("model"),
33+
io.Int.Input("tile_size", default=256, min=1, max=2048),
34+
io.Int.Input("swap_size", default=2, min=1, max=128),
35+
io.Int.Input("max_depth", default=0, min=0, max=10),
36+
io.Boolean.Input("scale_depth", default=False),
37+
],
38+
outputs=[
39+
io.Model.Output(),
40+
],
41+
)
42+
43+
@classmethod
44+
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
3845
latent_tile_size = max(32, tile_size) // 8
39-
self.temp = None
46+
temp = None
4047

4148
def hypertile_in(q, k, v, extra_options):
49+
nonlocal temp
4250
model_chans = q.shape[-2]
4351
orig_shape = extra_options['original_shape']
4452
apply_to = []
@@ -58,14 +66,15 @@ def hypertile_in(q, k, v, extra_options):
5866

5967
if nh * nw > 1:
6068
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
61-
self.temp = (nh, nw, h, w)
69+
temp = (nh, nw, h, w)
6270
return q, k, v
6371

6472
return q, k, v
6573
def hypertile_out(out, extra_options):
66-
if self.temp is not None:
67-
nh, nw, h, w = self.temp
68-
self.temp = None
74+
nonlocal temp
75+
if temp is not None:
76+
nh, nw, h, w = temp
77+
temp = None
6978
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
7079
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
7180
return out
@@ -76,6 +85,14 @@ def hypertile_out(out, extra_options):
7685
m.set_model_attn1_output_patch(hypertile_out)
7786
return (m, )
7887

79-
NODE_CLASS_MAPPINGS = {
80-
"HyperTile": HyperTile,
81-
}
88+
89+
class HyperTileExtension(ComfyExtension):
90+
@override
91+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
92+
return [
93+
HyperTile,
94+
]
95+
96+
97+
async def comfy_entrypoint() -> HyperTileExtension:
98+
return HyperTileExtension()

0 commit comments

Comments
 (0)