Skip to content

Commit 76eb1d7

Browse files
authored
convert nodes_rebatch.py to V3 schema (#9945)
1 parent c4a46e9 commit 76eb1d7

File tree

1 file changed

+55
-40
lines changed

1 file changed

+55
-40
lines changed

comfy_extras/nodes_rebatch.py

Lines changed: 55 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,25 @@
1+
from typing_extensions import override
12
import torch
23

3-
class LatentRebatch:
4-
@classmethod
5-
def INPUT_TYPES(s):
6-
return {"required": { "latents": ("LATENT",),
7-
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
8-
}}
9-
RETURN_TYPES = ("LATENT",)
10-
INPUT_IS_LIST = True
11-
OUTPUT_IS_LIST = (True, )
4+
from comfy_api.latest import ComfyExtension, io
125

13-
FUNCTION = "rebatch"
146

15-
CATEGORY = "latent/batch"
7+
class LatentRebatch(io.ComfyNode):
8+
@classmethod
9+
def define_schema(cls):
10+
return io.Schema(
11+
node_id="RebatchLatents",
12+
display_name="Rebatch Latents",
13+
category="latent/batch",
14+
is_input_list=True,
15+
inputs=[
16+
io.Latent.Input("latents"),
17+
io.Int.Input("batch_size", default=1, min=1, max=4096),
18+
],
19+
outputs=[
20+
io.Latent.Output(is_output_list=True),
21+
],
22+
)
1623

1724
@staticmethod
1825
def get_batch(latents, list_ind, offset):
@@ -53,7 +60,8 @@ def cat_batch(batch1, batch2):
5360
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
5461
return result
5562

56-
def rebatch(self, latents, batch_size):
63+
@classmethod
64+
def execute(cls, latents, batch_size):
5765
batch_size = batch_size[0]
5866

5967
output_list = []
@@ -63,24 +71,24 @@ def rebatch(self, latents, batch_size):
6371
for i in range(len(latents)):
6472
# fetch new entry of list
6573
#samples, masks, indices = self.get_batch(latents, i)
66-
next_batch = self.get_batch(latents, i, processed)
74+
next_batch = cls.get_batch(latents, i, processed)
6775
processed += len(next_batch[2])
6876
# set to current if current is None
6977
if current_batch[0] is None:
7078
current_batch = next_batch
7179
# add previous to list if dimensions do not match
7280
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
73-
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
81+
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
7482
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
7583
current_batch = next_batch
7684
# cat if everything checks out
7785
else:
78-
current_batch = self.cat_batch(current_batch, next_batch)
86+
current_batch = cls.cat_batch(current_batch, next_batch)
7987

8088
# add to list if dimensions gone above target batch size
8189
if current_batch[0].shape[0] > batch_size:
8290
num = current_batch[0].shape[0] // batch_size
83-
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
91+
sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
8492

8593
for i in range(num):
8694
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
@@ -89,31 +97,35 @@ def rebatch(self, latents, batch_size):
8997

9098
#add remainder
9199
if current_batch[0] is not None:
92-
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
100+
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
93101
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
94102

95103
#get rid of empty masks
96104
for s in output_list:
97105
if s['noise_mask'].mean() == 1.0:
98106
del s['noise_mask']
99107

100-
return (output_list,)
108+
return io.NodeOutput(output_list)
101109

102-
class ImageRebatch:
110+
class ImageRebatch(io.ComfyNode):
103111
@classmethod
104-
def INPUT_TYPES(s):
105-
return {"required": { "images": ("IMAGE",),
106-
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
107-
}}
108-
RETURN_TYPES = ("IMAGE",)
109-
INPUT_IS_LIST = True
110-
OUTPUT_IS_LIST = (True, )
111-
112-
FUNCTION = "rebatch"
112+
def define_schema(cls):
113+
return io.Schema(
114+
node_id="RebatchImages",
115+
display_name="Rebatch Images",
116+
category="image/batch",
117+
is_input_list=True,
118+
inputs=[
119+
io.Image.Input("images"),
120+
io.Int.Input("batch_size", default=1, min=1, max=4096),
121+
],
122+
outputs=[
123+
io.Image.Output(is_output_list=True),
124+
],
125+
)
113126

114-
CATEGORY = "image/batch"
115-
116-
def rebatch(self, images, batch_size):
127+
@classmethod
128+
def execute(cls, images, batch_size):
117129
batch_size = batch_size[0]
118130

119131
output_list = []
@@ -125,14 +137,17 @@ def rebatch(self, images, batch_size):
125137
for i in range(0, len(all_images), batch_size):
126138
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
127139

128-
return (output_list,)
140+
return io.NodeOutput(output_list)
141+
142+
143+
class RebatchExtension(ComfyExtension):
144+
@override
145+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
146+
return [
147+
LatentRebatch,
148+
ImageRebatch,
149+
]
129150

130-
NODE_CLASS_MAPPINGS = {
131-
"RebatchLatents": LatentRebatch,
132-
"RebatchImages": ImageRebatch,
133-
}
134151

135-
NODE_DISPLAY_NAME_MAPPINGS = {
136-
"RebatchLatents": "Rebatch Latents",
137-
"RebatchImages": "Rebatch Images",
138-
}
152+
async def comfy_entrypoint() -> RebatchExtension:
153+
return RebatchExtension()

0 commit comments

Comments (0)