@@ -1526,7 +1526,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1526
1526
bsz = model_input .shape [0 ]
1527
1527
1528
1528
# Sample a random timestep for each image
1529
- indices = torch .randint (0 , noise_scheduler_copy .config .num_train_timesteps , (bsz ,))
1529
+ # for weighting schemes where we sample timesteps non-uniformly
1530
+ if args .weighting_scheme == "logit_normal" :
1531
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
1532
+ u = torch .normal (mean = args .logit_mean , std = args .logit_std , size = (bsz ,), device = "cpu" )
1533
+ u = torch .nn .functional .sigmoid (u )
1534
+ elif args .weighting_scheme == "mode" :
1535
+ u = torch .rand (size = (bsz ,), device = "cpu" )
1536
+ u = 1 - u - args .mode_scale * (torch .cos (math .pi * u / 2 ) ** 2 - 1 + u )
1537
+ else :
1538
+ u = torch .rand (size = (bsz ,), device = "cpu" )
1539
+
1540
+ indices = (u * noise_scheduler_copy .config .num_train_timesteps ).long ()
1530
1541
timesteps = noise_scheduler_copy .timesteps [indices ].to (device = model_input .device )
1531
1542
1532
1543
# Add noise according to flow matching.
@@ -1560,18 +1571,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1560
1571
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
1561
1572
# Preconditioning of the model outputs.
1562
1573
model_pred = model_pred * (- sigmas ) + noisy_model_input
1563
-
1564
- # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
1574
+ # these weighting schemes use a uniform timestep sampling
1575
+ # and instead post-weight the loss
1565
1576
if args .weighting_scheme == "sigma_sqrt" :
1566
1577
weighting = (sigmas ** - 2.0 ).float ()
1567
- elif args .weighting_scheme == "logit_normal" :
1568
- # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
1569
- u = torch .normal (mean = args .logit_mean , std = args .logit_std , size = (bsz ,), device = accelerator .device )
1570
- weighting = torch .nn .functional .sigmoid (u )
1571
- elif args .weighting_scheme == "mode" :
1572
- # See sec 3.1 in the SD3 paper (20).
1573
- u = torch .rand (size = (bsz ,), device = accelerator .device )
1574
- weighting = 1 - u - args .mode_scale * (torch .cos (math .pi * u / 2 ) ** 2 - 1 + u )
1578
+ elif args .weighting_scheme == "cosmap" :
1579
+ bot = 1 - 2 * sigmas + 2 * sigmas ** 2
1580
+ weighting = 2 / (math .pi * bot )
1581
+ else :
1582
+ weighting = torch .ones_like (sigmas )
1575
1583
1576
1584
# simplified flow matching aka 0-rectified flow matching loss
1577
1585
# target = model_input - noise
0 commit comments