Status: Open
Labels: stale (issues that haven't received updates)
Description
Thanks to Rafie Walker's code, we can try to train SD3 models with flow matching! But some places don't seem to match what's in the paper.

Rafie Walker's code is below:
```python
import math

import torch


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """Sample timesteps u in [0, 1] according to the chosen density."""
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        # "Mode" sampling: maps uniform u through f_mode with scale mode_scale.
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        # Plain uniform sampling.
        u = torch.rand(size=(batch_size,), device="cpu")
    return u


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Per-sample loss weighting as a function of sigma."""
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting
```
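As a side note on the "mode" branch above: the mapping t = 1 - u - s·(cos²(πu/2) - 1 + u) does keep the sampled value inside [0, 1] for moderate mode_scale. A quick self-contained check (pure Python, no torch; the value 0.29 for the scale is just an illustrative choice, not taken from the code above):

```python
import math


def f_mode(u: float, s: float) -> float:
    # The mapping used in the "mode" branch:
    # t = 1 - u - s * (cos(pi*u/2)**2 - 1 + u)
    return 1 - u - s * (math.cos(math.pi * u / 2) ** 2 - 1 + u)


s = 0.29  # illustrative mode_scale
ts = [f_mode(i / 100, s) for i in range(101)]
# Endpoints map to t = 1 and t = 0; all values stay in [0, 1]
# up to floating-point error.
print(min(ts), max(ts))
```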
My questions are below:
- When `weighting_scheme == "mode"`, the code only computes `f_mode`. If we need `u` itself, shouldn't there be some additional operation?
- `cosmap` seems to compute the weighting of the timesteps (i.e., a sampling density), not the weighting of the loss?
- When we use `logit_normal`, it is based on the RF setting, so the loss weight should be t/(1-t); but the code doesn't compute that weight and instead returns `torch.ones_like(sigmas)`?
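On the second question: if I read the CosMap part of the SD3 paper correctly, the timestep mapping is t = f(u) = 1 - 1/(tan(πu/2) + 1) for uniform u, whose change-of-variables density is exactly 2/(π(1 - 2t + 2t²)), i.e., the same expression the `cosmap` branch computes. A quick numerical sanity check (pure Python, my own illustration, not part of the diffusers code):

```python
import math


def f_cosmap(u: float) -> float:
    # CosMap timestep mapping: t = 1 - 1/(tan(pi*u/2) + 1)
    return 1 - 1 / (math.tan(math.pi * u / 2) + 1)


def cosmap_expression(t: float) -> float:
    # The "cosmap" branch of compute_loss_weighting_for_sd3:
    # 2 / (pi * (1 - 2t + 2t^2))
    return 2 / (math.pi * (1 - 2 * t + 2 * t * t))


# Density of t = f_cosmap(u) for u ~ Uniform(0, 1) is p(t) = 1 / |dt/du|,
# estimated here with a central finite difference.
eps = 1e-6
for u in [0.1, 0.25, 0.5, 0.75, 0.9]:
    t = f_cosmap(u)
    dt_du = (f_cosmap(u + eps) - f_cosmap(u - eps)) / (2 * eps)
    assert abs(1 / dt_du - cosmap_expression(t)) < 1e-4
print("cosmap expression matches the CosMap sampling density")
```

So the `cosmap` "weighting" coincides with the sampling density of the timesteps, which is what the question is pointing out.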
So I think some modifications are needed to correctly compute the SD3 loss!
Thanks for discussing this together!
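To make the third question concrete, here is a hypothetical sketch of what a t/(1-t) loss weight could look like (my own illustration only; `rf_lognorm_weight` is a made-up name, and this is not a diffusers API or a proposed patch):

```python
def rf_lognorm_weight(t: float) -> float:
    # Hypothetical t/(1-t) loss weight for the RF/lognorm setting,
    # as suggested in the third question above. NOT diffusers code.
    return t / (1 - t)


# The weight grows as t approaches 1 and vanishes as t approaches 0.
for t in (0.1, 0.5, 0.9):
    print(t, rf_lognorm_weight(t))
```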