
The density_for_timestep_sampling and loss_weighting for SD3 training #9056

@DidiD1

Description


Thanks to Rafie Walker's code, we can try training SD3 models with flow matching!
However, some parts don't seem to match what the SD3 paper describes.
Rafie Walker's code is below:

import math

import torch


def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    """Sample batch_size timestep values u according to weighting_scheme."""
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        # Draw u ~ U(0, 1), then apply f_mode(u; s) with s = mode_scale.
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        # Plain uniform sampling.
        u = torch.rand(size=(batch_size,), device="cpu")
    return u

def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Per-sample loss weights for the noise levels in sigmas."""
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        # weighting = 2 / (pi * (1 - 2*sigma + 2*sigma^2))
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        # Any other scheme: every sample gets weight 1.
        weighting = torch.ones_like(sigmas)
    return weighting
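To see how the two helpers fit together, here is a minimal, dependency-free sketch of one weighted flow-matching loss evaluation. It assumes the rectified-flow interpolation x_t = (1 - t)·x0 + t·ε with velocity target ε - x0, uses the sampled value directly as the noise level σ = t (no timestep shift or discretization), and uses a hypothetical `dummy_model(x_t, t)` in place of the transformer. `sample_t`, `loss_weight`, and `weighted_fm_loss` are illustrative names that re-implement the logic above with plain floats; they are not the diffusers API.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def sample_t(scheme: str, rng: random.Random, logit_mean: float = 0.0,
             logit_std: float = 1.0, mode_scale: float = 0.0) -> float:
    """Draw one timestep t, mirroring compute_density_for_timestep_sampling."""
    if scheme == "logit_normal":
        # Sigmoid of a Gaussian draw gives a logit-normal sample in (0, 1).
        return 1.0 / (1.0 + math.exp(-rng.gauss(logit_mean, logit_std)))
    u = rng.random()
    if scheme == "mode":
        # Applies f_mode(u; s) to the uniform draw; mode_scale = 0 gives 1 - u.
        return 1 - u - mode_scale * (math.cos(math.pi * u / 2) ** 2 - 1 + u)
    return u  # plain uniform sampling


def loss_weight(scheme: str, sigma: float) -> float:
    """Per-sample loss weight, mirroring compute_loss_weighting_for_sd3."""
    if scheme == "sigma_sqrt":
        return sigma ** -2.0
    if scheme == "cosmap":
        return 2 / (math.pi * (1 - 2 * sigma + 2 * sigma ** 2))
    return 1.0  # any other scheme: unweighted


def weighted_fm_loss(model: Callable[[Vector, float], Vector], x0: Vector,
                     t: float, weight: float, rng: random.Random) -> float:
    """Weighted squared error between predicted and target velocity for one sample."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [(1 - t) * a + t * e for a, e in zip(x0, eps)]  # interpolate data -> noise
    target = [e - a for a, e in zip(x0, eps)]             # velocity d x_t / d t
    pred = model(x_t, t)
    mse = sum((p - y) ** 2 for p, y in zip(pred, target)) / len(x0)
    return weight * mse


def dummy_model(x_t: Vector, t: float) -> Vector:
    """Hypothetical stand-in for the transformer: always predicts zero velocity."""
    return [0.0] * len(x_t)


rng = random.Random(0)
t = sample_t("logit_normal", rng)
loss = weighted_fm_loss(dummy_model, [0.5, -0.5, 1.0], t, loss_weight("cosmap", t), rng)
print(f"t = {t:.4f}, weighted loss = {loss:.4f}")
```

In this sketch the sampling density decides which t values are visited, while the weighting only rescales each sample's squared error; the two are applied at different points of the step.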

My questions are:

  1. When weighting_scheme == "mode", the code only computes f_mode. Isn't some additional operation needed to obtain the sampled value u?
  2. Cosmap seems to compute a weighting over timesteps (a sampling density), not a weighting of the loss?
  3. logit_normal is based on the RF setting, so the loss weight should be t/(1-t). Why does the code not compute this weight and use torch.ones_like(sigmas) instead?
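For reference, the Section 3.1 quantities these questions refer to can be written out directly. This is a minimal pure-Python sketch, assuming the definitions as stated in the SD3 paper: the mode map f_mode(u; s), the CosMap map f(u) = 1 - 1/(tan(πu/2) + 1) with density π_cosmap(t) = 2/(π - 2πt + 2πt²), the logit-normal density π_ln(t; m, s), and the rectified-flow weight t/(1 - t). All function names are hypothetical helpers for illustration. The checks show that π_cosmap (the same expression the `cosmap` branch returns) is the density of t = f(u) for uniform u, and where f_mode sends the endpoints u = 0 and u = 1.

```python
import math
from typing import Callable


def f_mode(u: float, s: float) -> float:
    """Mode-sampling map: t = 1 - u - s * (cos^2(pi*u/2) - 1 + u)."""
    return 1 - u - s * (math.cos(math.pi * u / 2) ** 2 - 1 + u)


def f_cosmap(u: float) -> float:
    """CosMap sampling map: t = 1 - 1 / (tan(pi*u/2) + 1)."""
    return 1 - 1 / (math.tan(math.pi * u / 2) + 1)


def cosmap_cdf(t: float) -> float:
    """Inverse of f_cosmap, i.e. the CDF of t = f_cosmap(U) for U ~ U(0, 1)."""
    return (2 / math.pi) * math.atan(t / (1 - t))


def cosmap_density(t: float) -> float:
    """pi_cosmap(t) = 2 / (pi - 2*pi*t + 2*pi*t^2), the derivative of cosmap_cdf."""
    return 2 / (math.pi - 2 * math.pi * t + 2 * math.pi * t ** 2)


def logit_normal_density(t: float, m: float = 0.0, s: float = 1.0) -> float:
    """pi_ln(t; m, s): density of sigmoid(N(m, s^2)) at t in (0, 1)."""
    logit = math.log(t / (1 - t))
    return math.exp(-((logit - m) ** 2) / (2 * s ** 2)) / (s * math.sqrt(2 * math.pi) * t * (1 - t))


def rf_weight(t: float) -> float:
    """Rectified-flow weight w_t = t / (1 - t), as quoted in question 3."""
    return t / (1 - t)


def midpoint_integral(f: Callable[[float], float], n: int = 100_000) -> float:
    """Midpoint-rule integral of f over (0, 1); never evaluates the endpoints."""
    h = 1.0 / n
    return sum(f((i + 0.5) * h) for i in range(n)) * h


# Both densities should integrate to approximately 1 over (0, 1).
print(midpoint_integral(cosmap_density), midpoint_integral(logit_normal_density))
# f_cosmap undoes cosmap_cdf, so this prints a value very close to 0.25.
print(f_cosmap(cosmap_cdf(0.25)))
# f_mode sends u = 0 to t = 1 and u = 1 to (numerically) t = 0; s = 0 gives t = 1 - u.
print(f_mode(0.0, 1.29), f_mode(1.0, 1.29), f_mode(0.3, 0.0))
```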

So I think some modifications are needed to compute the SD3 loss correctly.
Thanks for discussing this together!


Labels: stale (issues that haven't received updates)
