@@ -425,6 +425,11 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--debug_loss",
+        action="store_true",
+        help="debug loss for each image, if filenames are available in the dataset",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -603,6 +608,7 @@ def main(args):
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
+
     if args.pretrained_vae_model_name_or_path is None:
         vae.to(accelerator.device, dtype=torch.float32)
     else:
@@ -890,13 +896,17 @@ def preprocess_train(examples):
         tokens_one, tokens_two = tokenize_captions(examples)
         examples["input_ids_one"] = tokens_one
         examples["input_ids_two"] = tokens_two
+        if args.debug_loss:
+            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+            if fnames:
+                examples["filenames"] = fnames
         return examples
 
     with accelerator.main_process_first():
         if args.max_train_samples is not None:
             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
 
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -905,14 +915,19 @@ def collate_fn(examples):
         crop_top_lefts = [example["crop_top_lefts"] for example in examples]
         input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
         input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
-        return {
+        result = {
             "pixel_values": pixel_values,
             "input_ids_one": input_ids_one,
             "input_ids_two": input_ids_two,
             "original_sizes": original_sizes,
             "crop_top_lefts": crop_top_lefts,
         }
 
+        filenames = [example["filenames"] for example in examples if "filenames" in example]
+        if filenames:
+            result["filenames"] = filenames
+        return result
+
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
@@ -1105,7 +1120,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                 loss = loss.mean()
-
+                if args.debug_loss and "filenames" in batch:
+                    for fname in batch["filenames"]:
+                        accelerator.log({"loss_for_" + fname: loss}, step=global_step)
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
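
Read together, the hunks above thread optional per-image filenames from preprocess_train through collate_fn into the training loop, where the current step's loss is logged once per source image under a "loss_for_<filename>" key. The standalone sketch below illustrates that flow under simplified assumptions: the example data is hypothetical, and print stands in for accelerator.log so the snippet runs without Accelerate.

import os

import torch


def collate_debug(examples):
    # Mirrors the patched collate_fn above: "filenames" is only attached to the
    # batch when preprocessing was able to recover source file names.
    batch = {"pixel_values": torch.stack([example["pixel_values"] for example in examples])}
    filenames = [example["filenames"] for example in examples if "filenames" in example]
    if filenames:
        batch["filenames"] = filenames
    return batch


# Hypothetical mini-batch of two examples whose source files are known.
examples = [
    {"pixel_values": torch.randn(3, 64, 64), "filenames": os.path.basename("/data/cat_01.png")},
    {"pixel_values": torch.randn(3, 64, 64), "filenames": os.path.basename("/data/dog_07.png")},
]
batch = collate_debug(examples)

loss = torch.tensor(0.123)  # stand-in for the per-step MSE loss computed in the training loop
if "filenames" in batch:
    for fname in batch["filenames"]:
        # The script itself calls accelerator.log({...}, step=global_step); print keeps
        # this sketch free of the Accelerate dependency.
        print({"loss_for_" + fname: loss.item()})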