Commit 4edde13

[SD3 training] refactor the density and weighting utilities. (#8591)

1 parent 074a7cc commit 4edde13

3 files changed: +89 −48 lines

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 27 additions & 25 deletions
@@ -53,7 +53,11 @@
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params
+from diffusers.training_utils import (
+    cast_training_params,
+    compute_density_for_timestep_sampling,
+    compute_loss_weighting_for_sd3,
+)
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
@@ -473,11 +477,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+    )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
     parser.add_argument(
         "--optimizer",
         type=str,
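For illustration, a minimal standalone parser (hypothetical, mirroring only the flags added above) shows how the new options parse and default together:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
)
parser.add_argument("--logit_mean", type=float, default=0.0)
parser.add_argument("--logit_std", type=float, default=1.0)
parser.add_argument("--mode_scale", type=float, default=1.29)

# Only `--weighting_scheme` is overridden here; the logit/mode knobs keep their defaults.
args = parser.parse_args(["--weighting_scheme", "logit_normal"])
print(args.weighting_scheme, args.logit_mean, args.logit_std, args.mode_scale)
# logit_normal 0.0 1.0 1.29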
@@ -1477,16 +1490,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
-
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
 
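To make the new sampling path concrete, here is a minimal, self-contained sketch (not part of the diff; `bsz` and `num_train_timesteps` are illustrative values) of how the sampled density `u` becomes discrete schedule indices:

import torch

from diffusers.training_utils import compute_density_for_timestep_sampling

bsz = 4
num_train_timesteps = 1000

u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=bsz,
    logit_mean=0.0,
    logit_std=1.0,
    mode_scale=1.29,
)
# `u` is a CPU tensor of shape (bsz,); for "logit_normal" its values lie in (0, 1).
# Scaling by the number of training timesteps and truncating yields the integer
# indices used to look up `noise_scheduler_copy.timesteps` in the script above.
indices = (u * num_train_timesteps).long()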
@@ -1507,19 +1517,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # Preconditioning of the model outputs.
                 model_pred = model_pred * (-sigmas) + noisy_model_input
 
-                # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
 
-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input
 
                 if args.with_prior_preservation:
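For context, a simplified sketch (assumed tensor shapes; not the training script's exact loss reduction) of how the post-hoc `weighting` scales the flow-matching MSE between `model_pred` and `target`:

import torch

from diffusers.training_utils import compute_loss_weighting_for_sd3

bsz, c, h, w = 4, 16, 64, 64
sigmas = torch.rand(bsz, 1, 1, 1) * 0.9 + 0.1  # per-sample noise levels, kept away from 0
model_pred = torch.randn(bsz, c, h, w)         # stands in for the preconditioned prediction
target = torch.randn(bsz, c, h, w)             # stands in for `model_input`

weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)
# Weight the squared error elementwise, reduce per sample, then average over the batch.
loss = (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(bsz, -1).mean(dim=1).mean()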

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 23 additions & 23 deletions
@@ -51,6 +51,7 @@
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
 from diffusers.utils import (
     check_min_version,
     is_wandb_available,
@@ -471,11 +472,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+    )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1541,16 +1551,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
-
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
 
@@ -1587,16 +1594,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_pred = model_pred * (-sigmas) + noisy_model_input
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
 
-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input
 
                 if args.with_prior_preservation:

src/diffusers/training_utils.py

Lines changed: 39 additions & 0 deletions
@@ -1,5 +1,6 @@
 import contextlib
 import copy
+import math
 import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
     set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
 
 
+def compute_density_for_timestep_sampling(
+    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+):
+    """Compute the density for sampling the timesteps when doing SD3 training.
+
+    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+    """
+    if weighting_scheme == "logit_normal":
+        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+        u = torch.nn.functional.sigmoid(u)
+    elif weighting_scheme == "mode":
+        u = torch.rand(size=(batch_size,), device="cpu")
+        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+    else:
+        u = torch.rand(size=(batch_size,), device="cpu")
+    return u
+
+
+def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
+    """Computes loss weighting scheme for SD3 training.
+
+    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+    """
+    if weighting_scheme == "sigma_sqrt":
+        weighting = (sigmas**-2.0).float()
+    elif weighting_scheme == "cosmap":
+        bot = 1 - 2 * sigmas + 2 * sigmas**2
+        weighting = 2 / (math.pi * bot)
+    else:
+        weighting = torch.ones_like(sigmas)
+    return weighting
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
