1
1
#Taken from: https://github.com/tfernd/HyperTile/
2
2
3
3
import math
4
+ from typing_extensions import override
4
5
from einops import rearrange
5
6
# Use torch rng for consistency across generations
6
7
from torch import randint
8
+ from comfy_api .latest import ComfyExtension , io
7
9
8
10
def random_divisor (value : int , min_value : int , / , max_options : int = 1 ) -> int :
9
11
min_value = min (min_value , value )
@@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
20
22
21
23
return ns [idx ]
22
24
23
- class HyperTile :
25
+ class HyperTile ( io . ComfyNode ) :
24
26
@classmethod
25
- def INPUT_TYPES (s ):
26
- return {"required" : { "model" : ("MODEL" ,),
27
- "tile_size" : ("INT" , {"default" : 256 , "min" : 1 , "max" : 2048 }),
28
- "swap_size" : ("INT" , {"default" : 2 , "min" : 1 , "max" : 128 }),
29
- "max_depth" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 10 }),
30
- "scale_depth" : ("BOOLEAN" , {"default" : False }),
31
- }}
32
- RETURN_TYPES = ("MODEL" ,)
33
- FUNCTION = "patch"
34
-
35
- CATEGORY = "model_patches/unet"
36
-
37
- def patch (self , model , tile_size , swap_size , max_depth , scale_depth ):
27
+ def define_schema (cls ):
28
+ return io .Schema (
29
+ node_id = "HyperTile" ,
30
+ category = "model_patches/unet" ,
31
+ inputs = [
32
+ io .Model .Input ("model" ),
33
+ io .Int .Input ("tile_size" , default = 256 , min = 1 , max = 2048 ),
34
+ io .Int .Input ("swap_size" , default = 2 , min = 1 , max = 128 ),
35
+ io .Int .Input ("max_depth" , default = 0 , min = 0 , max = 10 ),
36
+ io .Boolean .Input ("scale_depth" , default = False ),
37
+ ],
38
+ outputs = [
39
+ io .Model .Output (),
40
+ ],
41
+ )
42
+
43
+ @classmethod
44
+ def execute (cls , model , tile_size , swap_size , max_depth , scale_depth ) -> io .NodeOutput :
38
45
latent_tile_size = max (32 , tile_size ) // 8
39
- self . temp = None
46
+ temp = None
40
47
41
48
def hypertile_in (q , k , v , extra_options ):
49
+ nonlocal temp
42
50
model_chans = q .shape [- 2 ]
43
51
orig_shape = extra_options ['original_shape' ]
44
52
apply_to = []
@@ -58,14 +66,15 @@ def hypertile_in(q, k, v, extra_options):
58
66
59
67
if nh * nw > 1 :
60
68
q = rearrange (q , "b (nh h nw w) c -> (b nh nw) (h w) c" , h = h // nh , w = w // nw , nh = nh , nw = nw )
61
- self . temp = (nh , nw , h , w )
69
+ temp = (nh , nw , h , w )
62
70
return q , k , v
63
71
64
72
return q , k , v
65
73
def hypertile_out (out , extra_options ):
66
- if self .temp is not None :
67
- nh , nw , h , w = self .temp
68
- self .temp = None
74
+ nonlocal temp
75
+ if temp is not None :
76
+ nh , nw , h , w = temp
77
+ temp = None
69
78
out = rearrange (out , "(b nh nw) hw c -> b nh nw hw c" , nh = nh , nw = nw )
70
79
out = rearrange (out , "b nh nw (h w) c -> b (nh h nw w) c" , h = h // nh , w = w // nw )
71
80
return out
@@ -76,6 +85,14 @@ def hypertile_out(out, extra_options):
76
85
m .set_model_attn1_output_patch (hypertile_out )
77
86
return (m , )
78
87
79
- NODE_CLASS_MAPPINGS = {
80
- "HyperTile" : HyperTile ,
81
- }
88
+
89
+ class HyperTileExtension (ComfyExtension ):
90
+ @override
91
+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
92
+ return [
93
+ HyperTile ,
94
+ ]
95
+
96
+
97
+ async def comfy_entrypoint () -> HyperTileExtension :
98
+ return HyperTileExtension ()
0 commit comments