
Commit c101066

Correct SNR weighted loss in v-prediction case by only adding 1 to SNR on the denominator (#6307)
* fix minsnr implementation for v-prediction case
* format code
* always compute snr when snr_gamma is specified

Co-authored-by: Sayak Paul <[email protected]>
1 parent d4c7ab7 commit c101066
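In effect, the Min-SNR weight is now clamped first and the prediction-type-specific denominator is applied afterwards, so the v-prediction branch divides by SNR + 1 instead of shifting SNR before both the clamp and the division. A minimal standalone sketch of the corrected PyTorch weighting, mirroring the code added below (the min_snr_weights helper name is illustrative, not part of the repo; snr is assumed to be a 1-D tensor of per-timestep SNR values):

import torch

def min_snr_weights(snr: torch.Tensor, snr_gamma: float, prediction_type: str) -> torch.Tensor:
    # Clamp SNR at snr_gamma: min(SNR, gamma).
    weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0]
    if prediction_type == "epsilon":
        # Epsilon objective: min(SNR, gamma) / SNR.
        weights = weights / snr
    elif prediction_type == "v_prediction":
        # Velocity objective: only the denominator gains the +1, i.e. min(SNR, gamma) / (SNR + 1).
        weights = weights / (snr + 1)
    return weights

# Example: for SNR = 4 and gamma = 5, v-prediction now yields 4 / (4 + 1) = 0.8,
# whereas the previous code computed min(SNR + 1, gamma) / (SNR + 1) = 5 / 5 = 1.0.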

File tree

10 files changed: +69 -58 lines changed


examples/controlnet/train_controlnet_flax.py

Lines changed: 6 additions & 4 deletions
@@ -907,10 +907,12 @@ def compute_loss(params, minibatch, sample_rng):
 
     if args.snr_gamma is not None:
         snr = jnp.array(compute_snr(timesteps))
-        if noise_scheduler.config.prediction_type == "v_prediction":
-            # Velocity objective requires that we add one to SNR values before we divide by them.
-            snr = snr + 1
-        snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+        snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)
+        if noise_scheduler.config.prediction_type == "epsilon":
+            snr_loss_weights = snr_loss_weights / snr
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            snr_loss_weights = snr_loss_weights / (snr + 1)
 
         loss = loss * snr_loss_weights
 
     loss = loss.mean()
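The Flax script above applies the same fix with a jnp.where-based clamp. A minimal standalone sketch of that variant (the min_snr_weights_flax name is illustrative; snr is assumed to be a 1-D array of per-timestep SNR values, as returned by the script's compute_snr helper):

import jax.numpy as jnp

def min_snr_weights_flax(snr: jnp.ndarray, snr_gamma: float, prediction_type: str) -> jnp.ndarray:
    # Clamp SNR at snr_gamma: min(SNR, gamma).
    weights = jnp.where(snr < snr_gamma, snr, jnp.ones_like(snr) * snr_gamma)
    if prediction_type == "epsilon":
        weights = weights / snr
    elif prediction_type == "v_prediction":
        # Only the denominator gains the +1.
        weights = weights / (snr + 1)
    return weights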

examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py

Lines changed: 7 additions & 6 deletions
@@ -781,12 +781,13 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py

Lines changed: 7 additions & 6 deletions
@@ -631,12 +631,13 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py

Lines changed: 7 additions & 6 deletions
@@ -664,12 +664,13 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py

Lines changed: 7 additions & 6 deletions
@@ -811,12 +811,13 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

Lines changed: 7 additions & 6 deletions
@@ -848,12 +848,13 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/text_to_image/train_text_to_image.py

Lines changed: 7 additions & 6 deletions
@@ -943,12 +943,13 @@ def unwrap_model(model):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 7 additions & 6 deletions
@@ -759,12 +759,13 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 7 additions & 6 deletions
@@ -1062,12 +1062,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
    snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 7 additions & 6 deletions
@@ -1087,12 +1087,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective requires that we add one to SNR values before we divide by them.
-        snr = snr + 1
-    mse_loss_weights = (
-        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-    )
+    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+        dim=1
+    )[0]
+    if noise_scheduler.config.prediction_type == "epsilon":
+        mse_loss_weights = mse_loss_weights / snr
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        mse_loss_weights = mse_loss_weights / (snr + 1)
 
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
