
Commit 181a9bf

Support Multi Image-Caption dataset in lora training node (Comfy-Org#8819)
* initial impl of multi img/text dataset
* Update nodes_train.py
* Support Kohya-ss structure
1 parent aac10ad commit 181a9bf
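
In short, the commit adds a LoadImageTextSetFromFolderNode that loads images together with same-named .txt caption files and teaches the training loop to keep one caption paired with each image. The loader accepts both a flat folder and kohya-ss/sd-scripts style subfolders, where a numeric prefix sets a repeat count. A hypothetical input layout under that reading (all names invented for illustration):

```text
input/my_dataset/
├── photo1.png        # flat layout: caption read from photo1.txt
├── photo1.txt
└── 10_style/         # kohya-ss layout: "10_" repeats these images 10 times
    ├── img_a.jpg
    ├── img_a.txt
    └── img_b.webp    # no img_b.txt, so the empty caption is used
```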

File tree: 1 file changed

comfy_extras/nodes_train.py: 115 additions, 10 deletions
```diff
@@ -75,7 +75,7 @@ def move_to(self, device):
         return self.passive_memory_usage()


-def load_and_process_images(image_files, input_dir, resize_method="None"):
+def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
     """Utility function to load and process a list of images.

     Args:
```
```diff
@@ -90,7 +90,6 @@ def load_and_process_images(image_files, input_dir, resize_method="None"):
         raise ValueError("No valid images found in input")

     output_images = []
-    w, h = None, None

     for file in image_files:
         image_path = os.path.join(input_dir, file)
```
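
Hoisting w and h into the signature lets a caller pin the target size up front, while leaving them as None keeps the old behavior of sizing from the first image. A hypothetical call under that reading (file names and directory invented):

```python
# Hypothetical usage: force 512x512 with cropping; w=None, h=None would
# preserve the original behavior of sizing from the first image.
images = load_and_process_images(
    ["a.png", "b.jpg"], "input/my_dataset", resize_method="Crop", w=512, h=512
)
```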
```diff
@@ -206,6 +205,103 @@ def load_images(self, folder, resize_method):
         return (output_tensor,)


+class LoadImageTextSetFromFolderNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
+                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
+            },
+            "optional": {
+                "resize_method": (
+                    ["None", "Stretch", "Crop", "Pad"],
+                    {"default": "None"},
+                ),
+                "width": (
+                    IO.INT,
+                    {
+                        "default": -1,
+                        "min": -1,
+                        "max": 10000,
+                        "step": 1,
+                        "tooltip": "The width to resize the images to. -1 means use the original width.",
+                    },
+                ),
+                "height": (
+                    IO.INT,
+                    {
+                        "default": -1,
+                        "min": -1,
+                        "max": 10000,
+                        "step": 1,
+                        "tooltip": "The height to resize the images to. -1 means use the original height.",
+                    },
+                )
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
+    FUNCTION = "load_images"
+    CATEGORY = "loaders"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Loads a batch of images and caption from a directory for training."
+
+    def load_images(self, folder, clip, resize_method, width=None, height=None):
+        if clip is None:
+            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
+
+        logging.info(f"Loading images from folder: {folder}")
+
+        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
+        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
+
+        image_files = []
+        for item in os.listdir(sub_input_dir):
+            path = os.path.join(sub_input_dir, item)
+            if any(item.lower().endswith(ext) for ext in valid_extensions):
+                image_files.append(path)
+            elif os.path.isdir(path):
+                # Support kohya-ss/sd-scripts folder structure
+                repeat = 1
+                if item.split("_")[0].isdigit():
+                    repeat = int(item.split("_")[0])
+                image_files.extend([
+                    os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
+                ] * repeat)
+
+        caption_file_path = [
+            f.replace(os.path.splitext(f)[1], ".txt")
+            for f in image_files
+        ]
+        captions = []
+        for caption_file in caption_file_path:
+            caption_path = os.path.join(sub_input_dir, caption_file)
+            if os.path.exists(caption_path):
+                with open(caption_path, "r", encoding="utf-8") as f:
+                    caption = f.read().strip()
+                    captions.append(caption)
+            else:
+                captions.append("")
+
+        width = width if width != -1 else None
+        height = height if height != -1 else None
+        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
+
+        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
+
+        logging.info(f"Encoding captions from {sub_input_dir}.")
+        conditions = []
+        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
+        for text in captions:
+            if text == "":
+                conditions.append(empty_cond)
+            tokens = clip.tokenize(text)
+            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
+        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
+        return (output_tensor, conditions)
+
+
 def draw_loss_graph(loss_map, steps):
     width, height = 500, 300
     img = Image.new("RGB", (width, height), "white")
```
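
One detail worth flagging in the caption-encoding loop above: the empty-caption branch appends empty_cond but does not continue, so the loop then tokenizes the empty string and extends conditions a second time (and append, unlike extend, nests the pre-encoded list). That inflates the condition count past the image count, which the new check in train() would then reject. A minimal sketch of what the branch presumably intends, assuming the same clip API used in the diff:

```python
# Hypothetical fix, not part of this commit: reuse the pre-encoded empty
# conditioning and skip re-encoding blank captions.
conditions = []
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
for text in captions:
    if text == "":
        conditions.extend(empty_cond)  # extend, matching the non-empty path
        continue                       # avoid encoding "" a second time
    tokens = clip.tokenize(text)
    conditions.extend(clip.encode_from_tokens_scheduled(tokens))
```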
```diff
@@ -381,6 +477,13 @@ def train(

         latents = latents["samples"].to(dtype)
         num_images = latents.shape[0]
+        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
+        if len(positive) == 1 and num_images > 1:
+            positive = positive * num_images
+        elif len(positive) != num_images:
+            raise ValueError(
+                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
+            )

         with torch.inference_mode(False):
             lora_sd = {}
```
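
The new guard broadcasts a single conditioning across every image and rejects any other mismatch. A small self-contained illustration of the rule with placeholder values:

```python
# Illustration only: "cond" stands in for one encoded caption.
positive = ["cond"]
num_images = 4
if len(positive) == 1 and num_images > 1:
    positive = positive * num_images          # -> one copy per image
elif len(positive) != num_images:
    raise ValueError("captions and images are misaligned")
assert len(positive) == num_images
```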
```diff
@@ -474,6 +577,7 @@ def train(
         # setup models
         for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
             patch(m)
+        mp.model.requires_grad_(False)
         comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

         # Setup sampler and guider like in test script
```
```diff
@@ -486,7 +590,6 @@ def loss_callback(loss):
         )
         guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
         guider.set_conds(positive) # Set conditioning from input
-        ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()

         # yoland: this currently resize to the first image in the dataset

```
```diff
@@ -495,21 +598,21 @@ def loss_callback(loss):
         try:
             for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
                 # Generate random sigma
-                sigma = mp.model.model_sampling.percent_to_sigma(
+                sigmas = [mp.model.model_sampling.percent_to_sigma(
                     torch.rand((1,)).item()
-                )
-                sigma = torch.tensor([sigma])
+                ) for _ in range(min(batch_size, num_images))]
+                sigmas = torch.tensor(sigmas)

                 noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)

                 indices = torch.randperm(num_images)[:batch_size]
-                ss.sample(
-                    noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
-                )
+                batch_latent = latents[indices].clone()
+                guider.set_conds([positive[i] for i in indices]) # Set conditioning from input
+                guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
         finally:
             for m in mp.model.modules():
                 unpatch(m)
-            del ss, train_sampler, optimizer
+            del train_sampler, optimizer
             torch.cuda.empty_cache()

         for adapter in all_weight_adapters:
```
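
Since every step now trains on a random subset of images, the loop draws one sigma per batch element and re-pairs the selected latents with their captions before calling guider.sample directly, which is why the deleted SamplerCustomAdvanced wrapper is no longer needed. A reduced sketch of just the pairing logic, with tensor stand-ins for the real latents and conds:

```python
import torch

num_images, batch_size = 8, 4
latents = torch.randn(num_images, 4, 64, 64)         # stand-in latent batch
positive = [f"cond_{i}" for i in range(num_images)]  # stand-in per-image conds

indices = torch.randperm(num_images)[:batch_size]    # random subset per step
batch_latent = latents[indices].clone()              # latents for this step
batch_conds = [positive[i] for i in indices]         # captions stay aligned
num_sigmas = min(batch_size, num_images)             # one sigma per batch item
```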
```diff
@@ -697,6 +800,7 @@ def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
     "SaveLoRANode": SaveLoRA,
     "LoraModelLoader": LoraModelLoader,
     "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
+    "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
     "LossGraphNode": LossGraphNode,
 }

```
```diff
@@ -705,5 +809,6 @@ def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
     "SaveLoRANode": "Save LoRA Weights",
     "LoraModelLoader": "Load LoRA Model",
     "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
+    "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
     "LossGraphNode": "Plot Loss Graph",
 }
```
