Commit 6946fac

Slickytail, kashif, and sayakpaul authored
Implement SD3 loss weighting (#8528)
* Add lognorm and cosmap weighting
* Implement mode sampling
* Update examples/dreambooth/train_dreambooth_lora_sd3.py
* Update examples/dreambooth/train_dreambooth_lora_sd3.py
* Update examples/dreambooth/train_dreambooth_sd3.py
* Update examples/dreambooth/train_dreambooth_sd3.py
* Update examples/dreambooth/train_dreambooth_sd3.py
* Update examples/dreambooth/train_dreambooth_lora_sd3.py
* Update examples/dreambooth/train_dreambooth_sd3.py
* Update examples/dreambooth/train_dreambooth_sd3.py
* Update examples/dreambooth/train_dreambooth_lora_sd3.py
* keep timestamp sampling fully on cpu

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 130dd93 commit 6946fac

2 files changed: +38 −20 lines

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 19 additions & 9 deletions

@@ -1462,7 +1462,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 bsz = model_input.shape[0]

                 # Sample a random timestep for each image
-                indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
+                # for weighting schemes where we sample timesteps non-uniformly
+                if args.weighting_scheme == "logit_normal":
+                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
+                    u = torch.nn.functional.sigmoid(u)
+                elif args.weighting_scheme == "mode":
+                    u = torch.rand(size=(bsz,), device="cpu")
+                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                else:
+                    u = torch.rand(size=(bsz,), device="cpu")
+
+                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                 # Add noise according to flow matching.

@@ -1483,16 +1494,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_pred = model_pred * (-sigmas) + noisy_model_input

                 # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
+                # these weighting schemes use a uniform timestep sampling
+                # and instead post-weight the loss
                 if args.weighting_scheme == "sigma_sqrt":
                     weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
-                    weighting = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    # See sec 3.1 in the SD3 paper (20).
-                    u = torch.rand(size=(bsz,), device=accelerator.device)
-                    weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                elif args.weighting_scheme == "cosmap":
+                    bot = 1 - 2 * sigmas + 2 * sigmas**2
+                    weighting = 2 / (math.pi * bot)
+                else:
+                    weighting = torch.ones_like(sigmas)

                 # simplified flow matching aka 0-rectified flow matching loss
                 # target = model_input - noise
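
The sampling logic added above can be read as a small standalone helper. The sketch below is illustrative only: the function name sample_timestep_density, its default argument values, and the 1000-timestep assumption in the usage lines are not part of the training scripts or the diffusers API; the branch bodies mirror the added diff lines.

import math

import torch


def sample_timestep_density(
    batch_size,
    weighting_scheme="logit_normal",
    logit_mean=0.0,
    logit_std=1.0,
    mode_scale=1.29,
):
    # Illustrative helper (not the diffusers API): draw one u in [0, 1] per sample.
    # Sampling stays on the CPU, matching the commit's "keep ... sampling fully on cpu".
    if weighting_scheme == "logit_normal":
        # Logit-normal sampling, Sec. 3.1 of the SD3 paper: u = sigmoid(N(mean, std)).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.sigmoid(u)
    elif weighting_scheme == "mode":
        # "Mode" sampling (Eq. (20) in Sec. 3.1 of the SD3 paper) reshapes a uniform draw.
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        # Plain uniform sampling for the post-hoc weighting schemes.
        u = torch.rand(size=(batch_size,), device="cpu")
    return u


# Map u onto discrete scheduler timesteps, as in the training loop
# (1000 train timesteps is assumed here for illustration).
u = sample_timestep_density(batch_size=4, weighting_scheme="logit_normal")
indices = (u * 1000).long()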

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 19 additions & 11 deletions

@@ -1526,7 +1526,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 bsz = model_input.shape[0]

                 # Sample a random timestep for each image
-                indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
+                # for weighting schemes where we sample timesteps non-uniformly
+                if args.weighting_scheme == "logit_normal":
+                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
+                    u = torch.nn.functional.sigmoid(u)
+                elif args.weighting_scheme == "mode":
+                    u = torch.rand(size=(bsz,), device="cpu")
+                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                else:
+                    u = torch.rand(size=(bsz,), device="cpu")
+
+                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                 # Add noise according to flow matching.

@@ -1560,18 +1571,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                 # Preconditioning of the model outputs.
                 model_pred = model_pred * (-sigmas) + noisy_model_input
-
-                # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
+                # these weighting schemes use a uniform timestep sampling
+                # and instead post-weight the loss
                 if args.weighting_scheme == "sigma_sqrt":
                     weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
-                    weighting = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    # See sec 3.1 in the SD3 paper (20).
-                    u = torch.rand(size=(bsz,), device=accelerator.device)
-                    weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                elif args.weighting_scheme == "cosmap":
+                    bot = 1 - 2 * sigmas + 2 * sigmas**2
+                    weighting = 2 / (math.pi * bot)
+                else:
+                    weighting = torch.ones_like(sigmas)

                 # simplified flow matching aka 0-rectified flow matching loss
                 # target = model_input - noise
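
The post-hoc loss weighting shared by both scripts can likewise be sketched as a standalone function. The name compute_loss_weighting and the usage lines at the end are illustrative assumptions, not the diffusers API; the branch bodies mirror the added diff lines.

import math

import torch


def compute_loss_weighting(sigmas, weighting_scheme):
    # Illustrative helper (not the diffusers API): per-sample loss weights for
    # schemes that keep uniform timestep sampling and post-weight the loss.
    if weighting_scheme == "sigma_sqrt":
        # Weight each sample by 1 / sigma^2.
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        # Cosine-map weighting: 2 / (pi * (1 - 2*sigma + 2*sigma**2)).
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        # logit_normal / mode already bias the timestep sampling itself,
        # so the per-sample loss weight stays uniform.
        weighting = torch.ones_like(sigmas)
    return weighting


# Usage sketch: weight the flow-matching error before reducing it, e.g.
# loss = (weighting * (model_pred - target) ** 2).mean()
sigmas = torch.rand(4, 1, 1, 1)  # broadcastable to the model prediction shape
weighting = compute_loss_weighting(sigmas, "cosmap")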
