1
+ from typing_extensions import override
1
2
import torch
2
3
3
- class LatentRebatch :
4
- @classmethod
5
- def INPUT_TYPES (s ):
6
- return {"required" : { "latents" : ("LATENT" ,),
7
- "batch_size" : ("INT" , {"default" : 1 , "min" : 1 , "max" : 4096 }),
8
- }}
9
- RETURN_TYPES = ("LATENT" ,)
10
- INPUT_IS_LIST = True
11
- OUTPUT_IS_LIST = (True , )
4
+ from comfy_api .latest import ComfyExtension , io
12
5
13
- FUNCTION = "rebatch"
14
6
15
- CATEGORY = "latent/batch"
7
class LatentRebatch(io.ComfyNode):
    """Regroup a list of LATENT dicts into batches of a requested size."""

    @classmethod
    def define_schema(cls):
        """Declare the node schema: a list of latents in, a list of latents out."""
        return io.Schema(
            node_id="RebatchLatents",
            display_name="Rebatch Latents",
            category="latent/batch",
            is_input_list=True,
            inputs=[
                io.Latent.Input("latents"),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Latent.Output(is_output_list=True),
            ],
        )
16
23
17
24
@staticmethod
18
25
def get_batch (latents , list_ind , offset ):
@@ -53,7 +60,8 @@ def cat_batch(batch1, batch2):
53
60
result = [torch .cat ((b1 , b2 )) if torch .is_tensor (b1 ) else b1 + b2 for b1 , b2 in zip (batch1 , batch2 )]
54
61
return result
55
62
56
- def rebatch (self , latents , batch_size ):
63
+ @classmethod
64
+ def execute (cls , latents , batch_size ):
57
65
batch_size = batch_size [0 ]
58
66
59
67
output_list = []
@@ -63,24 +71,24 @@ def rebatch(self, latents, batch_size):
63
71
for i in range (len (latents )):
64
72
# fetch new entry of list
65
73
#samples, masks, indices = self.get_batch(latents, i)
66
- next_batch = self .get_batch (latents , i , processed )
74
+ next_batch = cls .get_batch (latents , i , processed )
67
75
processed += len (next_batch [2 ])
68
76
# set to current if current is None
69
77
if current_batch [0 ] is None :
70
78
current_batch = next_batch
71
79
# add previous to list if dimensions do not match
72
80
elif next_batch [0 ].shape [- 1 ] != current_batch [0 ].shape [- 1 ] or next_batch [0 ].shape [- 2 ] != current_batch [0 ].shape [- 2 ]:
73
- sliced , _ = self .slice_batch (current_batch , 1 , batch_size )
81
+ sliced , _ = cls .slice_batch (current_batch , 1 , batch_size )
74
82
output_list .append ({'samples' : sliced [0 ][0 ], 'noise_mask' : sliced [1 ][0 ], 'batch_index' : sliced [2 ][0 ]})
75
83
current_batch = next_batch
76
84
# cat if everything checks out
77
85
else :
78
- current_batch = self .cat_batch (current_batch , next_batch )
86
+ current_batch = cls .cat_batch (current_batch , next_batch )
79
87
80
88
# add to list if dimensions gone above target batch size
81
89
if current_batch [0 ].shape [0 ] > batch_size :
82
90
num = current_batch [0 ].shape [0 ] // batch_size
83
- sliced , remainder = self .slice_batch (current_batch , num , batch_size )
91
+ sliced , remainder = cls .slice_batch (current_batch , num , batch_size )
84
92
85
93
for i in range (num ):
86
94
output_list .append ({'samples' : sliced [0 ][i ], 'noise_mask' : sliced [1 ][i ], 'batch_index' : sliced [2 ][i ]})
@@ -89,31 +97,35 @@ def rebatch(self, latents, batch_size):
89
97
90
98
#add remainder
91
99
if current_batch [0 ] is not None :
92
- sliced , _ = self .slice_batch (current_batch , 1 , batch_size )
100
+ sliced , _ = cls .slice_batch (current_batch , 1 , batch_size )
93
101
output_list .append ({'samples' : sliced [0 ][0 ], 'noise_mask' : sliced [1 ][0 ], 'batch_index' : sliced [2 ][0 ]})
94
102
95
103
#get rid of empty masks
96
104
for s in output_list :
97
105
if s ['noise_mask' ].mean () == 1.0 :
98
106
del s ['noise_mask' ]
99
107
100
- return (output_list , )
108
+ return io . NodeOutput (output_list )
101
109
102
- class ImageRebatch :
110
class ImageRebatch(io.ComfyNode):
    """Regroup a list of IMAGE batches into batches of a requested size."""

    @classmethod
    def define_schema(cls):
        """Declare the node schema: a list of images in, a list of images out."""
        return io.Schema(
            node_id="RebatchImages",
            display_name="Rebatch Images",
            category="image/batch",
            is_input_list=True,
            inputs=[
                io.Image.Input("images"),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Image.Output(is_output_list=True),
            ],
        )
113
126
114
- CATEGORY = "image/batch"
115
-
116
- def rebatch (self , images , batch_size ):
127
+ @classmethod
128
+ def execute (cls , images , batch_size ):
117
129
batch_size = batch_size [0 ]
118
130
119
131
output_list = []
@@ -125,14 +137,17 @@ def rebatch(self, images, batch_size):
125
137
for i in range (0 , len (all_images ), batch_size ):
126
138
output_list .append (torch .cat (all_images [i :i + batch_size ], dim = 0 ))
127
139
128
- return (output_list ,)
140
+ return io .NodeOutput (output_list )
141
+
142
+
143
class RebatchExtension(ComfyExtension):
    """Registers the rebatch nodes with the ComfyUI extension system."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every node class provided by this extension."""
        return [LatentRebatch, ImageRebatch]
129
150
130
- NODE_CLASS_MAPPINGS = {
131
- "RebatchLatents" : LatentRebatch ,
132
- "RebatchImages" : ImageRebatch ,
133
- }
134
151
135
- NODE_DISPLAY_NAME_MAPPINGS = {
136
- "RebatchLatents" : "Rebatch Latents" ,
137
- "RebatchImages" : "Rebatch Images" ,
138
- }
152
async def comfy_entrypoint() -> RebatchExtension:
    """Entry point called by ComfyUI to obtain this extension instance."""
    return RebatchExtension()
0 commit comments