5353 StableDiffusion3Pipeline ,
5454)
5555from diffusers .optimization import get_scheduler
56- from diffusers .training_utils import cast_training_params
56+ from diffusers .training_utils import (
57+ cast_training_params ,
58+ compute_density_for_timestep_sampling ,
59+ compute_loss_weighting_for_sd3 ,
60+ )
5761from diffusers .utils import (
5862 check_min_version ,
5963 convert_unet_state_dict_to_peft ,
@@ -473,11 +477,20 @@ def parse_args(input_args=None):
473477 ),
474478 )
475479 parser .add_argument (
476- "--weighting_scheme" , type = str , default = "logit_normal" , choices = ["sigma_sqrt" , "logit_normal" , "mode" ]
480+ "--weighting_scheme" , type = str , default = "sigma_sqrt" , choices = ["sigma_sqrt" , "logit_normal" , "mode" , "cosmap" ]
481+ )
482+ parser .add_argument (
483+ "--logit_mean" , type = float , default = 0.0 , help = "mean to use when using the `'logit_normal'` weighting scheme."
484+ )
485+ parser .add_argument (
486+ "--logit_std" , type = float , default = 1.0 , help = "std to use when using the `'logit_normal'` weighting scheme."
487+ )
488+ parser .add_argument (
489+ "--mode_scale" ,
490+ type = float ,
491+ default = 1.29 ,
492+ help = "Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`." ,
477493 )
478- parser .add_argument ("--logit_mean" , type = float , default = 0.0 )
479- parser .add_argument ("--logit_std" , type = float , default = 1.0 )
480- parser .add_argument ("--mode_scale" , type = float , default = 1.29 )
481494 parser .add_argument (
482495 "--optimizer" ,
483496 type = str ,
@@ -1477,16 +1490,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
14771490
14781491 # Sample a random timestep for each image
14791492 # for weighting schemes where we sample timesteps non-uniformly
1480- if args .weighting_scheme == "logit_normal" :
1481- # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
1482- u = torch .normal (mean = args .logit_mean , std = args .logit_std , size = (bsz ,), device = "cpu" )
1483- u = torch .nn .functional .sigmoid (u )
1484- elif args .weighting_scheme == "mode" :
1485- u = torch .rand (size = (bsz ,), device = "cpu" )
1486- u = 1 - u - args .mode_scale * (torch .cos (math .pi * u / 2 ) ** 2 - 1 + u )
1487- else :
1488- u = torch .rand (size = (bsz ,), device = "cpu" )
1489-
1493+ u = compute_density_for_timestep_sampling (
1494+ weighting_scheme = args .weighting_scheme ,
1495+ batch_size = bsz ,
1496+ logit_mean = args .logit_mean ,
1497+ logit_std = args .logit_std ,
1498+ mode_scale = args .mode_scale ,
1499+ )
14901500 indices = (u * noise_scheduler_copy .config .num_train_timesteps ).long ()
14911501 timesteps = noise_scheduler_copy .timesteps [indices ].to (device = model_input .device )
14921502
@@ -1507,19 +1517,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15071517 # Preconditioning of the model outputs.
15081518 model_pred = model_pred * (- sigmas ) + noisy_model_input
15091519
1510- # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
15111520 # these weighting schemes use a uniform timestep sampling
15121521 # and instead post-weight the loss
1513- if args .weighting_scheme == "sigma_sqrt" :
1514- weighting = (sigmas ** - 2.0 ).float ()
1515- elif args .weighting_scheme == "cosmap" :
1516- bot = 1 - 2 * sigmas + 2 * sigmas ** 2
1517- weighting = 2 / (math .pi * bot )
1518- else :
1519- weighting = torch .ones_like (sigmas )
1522+ weighting = compute_loss_weighting_for_sd3 (weighting_scheme = args .weighting_scheme , sigmas = sigmas )
15201523
1521- # simplified flow matching aka 0-rectified flow matching loss
1522- # target = model_input - noise
1524+ # flow matching loss
15231525 target = model_input
15241526
15251527 if args .with_prior_preservation :
0 commit comments