@@ -75,7 +75,7 @@ def move_to(self, device):
7575 return self .passive_memory_usage ()
7676
7777
78- def load_and_process_images (image_files , input_dir , resize_method = "None" ):
78+ def load_and_process_images (image_files , input_dir , resize_method = "None" , w = None , h = None ):
7979 """Utility function to load and process a list of images.
8080
8181 Args:
@@ -90,7 +90,6 @@ def load_and_process_images(image_files, input_dir, resize_method="None"):
9090 raise ValueError ("No valid images found in input" )
9191
9292 output_images = []
93- w , h = None , None
9493
9594 for file in image_files :
9695 image_path = os .path .join (input_dir , file )
@@ -206,6 +205,103 @@ def load_images(self, folder, resize_method):
206205 return (output_tensor ,)
207206
208207
class LoadImageTextSetFromFolderNode:
    """Load a batch of images plus per-image ``.txt`` captions from an input
    subfolder, for LoRA training.

    Returns an image batch tensor and a CLIP-encoded conditioning list with
    one entry per image (empty conditioning for images without a caption
    file).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
            },
            "optional": {
                "resize_method": (
                    ["None", "Stretch", "Crop", "Pad"],
                    {"default": "None"},
                ),
                "width": (
                    IO.INT,
                    {
                        "default": -1,
                        "min": -1,
                        "max": 10000,
                        "step": 1,
                        "tooltip": "The width to resize the images to. -1 means use the original width.",
                    },
                ),
                "height": (
                    IO.INT,
                    {
                        "default": -1,
                        "min": -1,
                        "max": 10000,
                        "step": 1,
                        "tooltip": "The height to resize the images to. -1 means use the original height.",
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
    FUNCTION = "load_images"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    DESCRIPTION = "Loads a batch of images and caption from a directory for training."

    def load_images(self, folder, clip, resize_method, width=None, height=None):
        """Load all images under *folder* and encode their sidecar captions.

        Args:
            folder: Name of a subfolder of the input directory.
            clip: CLIP model used to encode caption text.
            resize_method: One of "None", "Stretch", "Crop", "Pad".
            width, height: Target size; -1 (the widget sentinel) or None
                keeps the original dimension.

        Returns:
            (image_batch_tensor, conditions) tuple.

        Raises:
            RuntimeError: If *clip* is None.
            ValueError: Propagated from load_and_process_images when no
                valid images are found.
        """
        if clip is None:
            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")

        logging.info(f"Loading images from folder: {folder}")

        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]

        image_files = []
        for item in os.listdir(sub_input_dir):
            path = os.path.join(sub_input_dir, item)
            if any(item.lower().endswith(ext) for ext in valid_extensions):
                image_files.append(path)
            elif os.path.isdir(path):
                # Support kohya-ss/sd-scripts folder structure: a subfolder
                # named "N_something" repeats its contents N times.
                repeat = 1
                if item.split("_")[0].isdigit():
                    repeat = int(item.split("_")[0])
                image_files.extend([
                    os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
                ] * repeat)

        # BUGFIX: build the caption path by swapping the extension at the
        # end of the path. The previous str.replace(ext, ".txt") replaced the
        # FIRST occurrence of the extension substring anywhere in the path
        # (wrong for e.g. "a.png.png" or a directory name containing ".png").
        caption_file_path = [
            os.path.splitext(f)[0] + ".txt"
            for f in image_files
        ]
        captions = []
        for caption_file in caption_file_path:
            # image_files hold absolute paths, so os.path.join simply keeps
            # caption_file here; kept for backward compatibility.
            caption_path = os.path.join(sub_input_dir, caption_file)
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    caption = f.read().strip()
                captions.append(caption)
            else:
                # No sidecar .txt file: fall back to an empty caption.
                captions.append("")

        # -1 is the widget sentinel for "keep original dimension".
        width = width if width != -1 else None
        height = height if height != -1 else None
        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)

        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")

        logging.info(f"Encoding captions from {sub_input_dir}.")
        conditions = []
        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
        for text in captions:
            if text == "":
                # Reuse the pre-computed empty conditioning. Use extend (not
                # append) so both branches add same-shaped elements.
                conditions.extend(empty_cond)
                # BUGFIX: without this continue, empty captions were ALSO
                # tokenized and encoded below, adding a second entry per
                # image and breaking the images/conditions count match.
                continue
            tokens = clip.tokenize(text)
            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
        return (output_tensor, conditions)
304+
209305def draw_loss_graph (loss_map , steps ):
210306 width , height = 500 , 300
211307 img = Image .new ("RGB" , (width , height ), "white" )
@@ -381,6 +477,13 @@ def train(
381477
382478 latents = latents ["samples" ].to (dtype )
383479 num_images = latents .shape [0 ]
480+ logging .info (f"Total Images: { num_images } , Total Captions: { len (positive )} " )
481+ if len (positive ) == 1 and num_images > 1 :
482+ positive = positive * num_images
483+ elif len (positive ) != num_images :
484+ raise ValueError (
485+ f"Number of positive conditions ({ len (positive )} ) does not match number of images ({ num_images } )."
486+ )
384487
385488 with torch .inference_mode (False ):
386489 lora_sd = {}
@@ -474,6 +577,7 @@ def train(
474577 # setup models
475578 for m in find_all_highest_child_module_with_forward (mp .model .diffusion_model ):
476579 patch (m )
580+ mp .model .requires_grad_ (False )
477581 comfy .model_management .load_models_gpu ([mp ], memory_required = 1e20 , force_full_load = True )
478582
479583 # Setup sampler and guider like in test script
@@ -486,7 +590,6 @@ def loss_callback(loss):
486590 )
487591 guider = comfy_extras .nodes_custom_sampler .Guider_Basic (mp )
488592 guider .set_conds (positive ) # Set conditioning from input
489- ss = comfy_extras .nodes_custom_sampler .SamplerCustomAdvanced ()
490593
491594 # yoland: this currently resize to the first image in the dataset
492595
@@ -495,21 +598,21 @@ def loss_callback(loss):
495598 try :
496599 for step in (pbar := tqdm .trange (steps , desc = "Training LoRA" , smoothing = 0.01 , disable = not comfy .utils .PROGRESS_BAR_ENABLED )):
497600 # Generate random sigma
498- sigma = mp .model .model_sampling .percent_to_sigma (
601+ sigmas = [ mp .model .model_sampling .percent_to_sigma (
499602 torch .rand ((1 ,)).item ()
500- )
501- sigma = torch .tensor ([ sigma ] )
603+ ) for _ in range ( min ( batch_size , num_images ))]
604+ sigmas = torch .tensor (sigmas )
502605
503606 noise = comfy_extras .nodes_custom_sampler .Noise_RandomNoise (step * 1000 + seed )
504607
505608 indices = torch .randperm (num_images )[:batch_size ]
506- ss . sample (
507- noise , guider , train_sampler , sigma , { "samples" : latents [ indices ]. clone ()}
508- )
609+ batch_latent = latents [ indices ]. clone ()
610+ guider . set_conds ([ positive [ i ] for i in indices ]) # Set conditioning from input
611+ guider . sample ( noise . generate_noise ({ "samples" : batch_latent }), batch_latent , train_sampler , sigmas , seed = noise . seed )
509612 finally :
510613 for m in mp .model .modules ():
511614 unpatch (m )
512- del ss , train_sampler , optimizer
615+ del train_sampler , optimizer
513616 torch .cuda .empty_cache ()
514617
515618 for adapter in all_weight_adapters :
@@ -697,6 +800,7 @@ def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
697800 "SaveLoRANode" : SaveLoRA ,
698801 "LoraModelLoader" : LoraModelLoader ,
699802 "LoadImageSetFromFolderNode" : LoadImageSetFromFolderNode ,
803+ "LoadImageTextSetFromFolderNode" : LoadImageTextSetFromFolderNode ,
700804 "LossGraphNode" : LossGraphNode ,
701805}
702806
@@ -705,5 +809,6 @@ def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
705809 "SaveLoRANode" : "Save LoRA Weights" ,
706810 "LoraModelLoader" : "Load LoRA Model" ,
707811 "LoadImageSetFromFolderNode" : "Load Image Dataset from Folder" ,
812+ "LoadImageTextSetFromFolderNode" : "Load Image and Text Dataset from Folder" ,
708813 "LossGraphNode" : "Plot Loss Graph" ,
709814}
0 commit comments