
Commit 53e9aac

noskills and sayakpaul authored

log loss per image (#7278)

* log loss per image
* add commandline param for per image loss logging
* style
* debug-loss -> debug_loss

Co-authored-by: Sayak Paul <[email protected]>

1 parent 4142446 commit 53e9aac

1 file changed (+20, -3)

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 20 additions & 3 deletions
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--debug_loss",
+        action="store_true",
+        help="debug loss for each image, if filenames are available in the dataset",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
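
The flag is a standard store_true switch, so it defaults to False and takes no value on the command line. A minimal standalone sketch of just this flag (a throwaway parser for illustration, not the script's full parse_args):

import argparse

# Hypothetical standalone parser showing only the new flag.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--debug_loss",
    action="store_true",
    help="debug loss for each image, if filenames are available in the dataset",
)

args = parser.parse_args(["--debug_loss"])
assert args.debug_loss is True  # store_true flags default to False when absent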
@@ -603,6 +608,7 @@ def main(args):
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
+
     if args.pretrained_vae_model_name_or_path is None:
         vae.to(accelerator.device, dtype=torch.float32)
     else:
@@ -890,13 +896,17 @@ def preprocess_train(examples):
         tokens_one, tokens_two = tokenize_captions(examples)
         examples["input_ids_one"] = tokens_one
         examples["input_ids_two"] = tokens_two
+        if args.debug_loss:
+            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+            if fnames:
+                examples["filenames"] = fnames
         return examples
 
     with accelerator.main_process_first():
         if args.max_train_samples is not None:
             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
 
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
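
The guard on image.filename matters: PIL typically populates that attribute only for images opened from a path, so in-memory images contribute nothing and the "filenames" key is simply omitted. A small sketch of the behavior the hunk above relies on (hypothetical file a.png, created just for the demo):

from PIL import Image

# Create a tiny file so Image.open has a real path to remember.
Image.new("RGB", (8, 8)).save("a.png")

img = Image.open("a.png")
print(img.filename)                    # "a.png" -- usable for per-image logging

blank = Image.new("RGB", (8, 8))       # in-memory image, never touched disk
print(getattr(blank, "filename", ""))  # "" -- filtered out by the guard above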
@@ -905,14 +915,19 @@ def collate_fn(examples):
         crop_top_lefts = [example["crop_top_lefts"] for example in examples]
         input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
         input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
-        return {
+        result = {
             "pixel_values": pixel_values,
             "input_ids_one": input_ids_one,
             "input_ids_two": input_ids_two,
             "original_sizes": original_sizes,
             "crop_top_lefts": crop_top_lefts,
         }
 
+        filenames = [example["filenames"] for example in examples if "filenames" in example]
+        if filenames:
+            result["filenames"] = filenames
+        return result
+
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
@@ -1105,7 +1120,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                 loss = loss.mean()
-
+                if args.debug_loss and "filenames" in batch:
+                    for fname in batch["filenames"]:
+                        accelerator.log({"loss_for_" + fname: loss}, step=global_step)
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
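
Taken together: when the script is launched with --debug_loss and the dataset images carry filenames, every step emits one extra tracker scalar per image, keyed loss_for_<filename>. Note that loss.mean() has already been applied, so the logged value is the batch-mean loss attributed to each filename, not a true per-sample loss. A minimal sketch of the logging call in isolation (hypothetical filenames and values, assuming an Accelerator configured with a tensorboard tracker):

from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", project_dir="logs")
accelerator.init_trackers("debug_loss_demo")

global_step, loss = 1, 0.123              # stand-ins for the training-loop values
for fname in ["img_0.png", "img_1.png"]:  # hypothetical batch filenames
    # One scalar per image; the same batch-mean loss is reused for each file.
    accelerator.log({"loss_for_" + fname: loss}, step=global_step)

accelerator.end_training()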
