[feat] (vdm.py): add loss parametrization to image prediction #1

@relyativist

Description

According to doi.org/10.48550/arXiv.2401.06281 (4.2, equation 4.62), the image-prediction loss can be parametrized by multiplying pred_loss by the snr_t term. When the model predicts the image rather than the noise, VDM.forward() should therefore change as follows:

- pred_loss = ((model_out - noise) ** 2).sum((1, 2, 3))
+ snr_t = torch.exp(-gamma_t)
+ pred_loss = snr_t * ((model_out - x) ** 2).sum((1, 2, 3))
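
The proposed weighting can be sanity-checked numerically. The sketch below is a framework-free check in plain Python, not code from vdm.py. It assumes the variance-preserving schedule alpha_t^2 = sigmoid(-gamma_t), sigma_t^2 = sigmoid(gamma_t), under which SNR(t) = exp(-gamma_t). The names check_equivalence and x_hat are hypothetical, and x_hat stands in for the model output. It converts the image prediction into the noise prediction it implies and compares the two losses.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def check_equivalence(gamma_t, n=16, seed=0):
    """Return (noise-prediction loss, SNR-weighted image-prediction loss).

    Assumes a variance-preserving schedule: alpha^2 = sigmoid(-gamma_t),
    sigma^2 = sigmoid(gamma_t), so alpha^2 / sigma^2 = exp(-gamma_t).
    """
    rng = random.Random(seed)
    alpha = math.sqrt(sigmoid(-gamma_t))
    sigma = math.sqrt(sigmoid(gamma_t))

    x = [rng.gauss(0.0, 1.0) for _ in range(n)]        # clean "image"
    noise = [rng.gauss(0.0, 1.0) for _ in range(n)]    # sampled noise
    z = [alpha * xi + sigma * ei for xi, ei in zip(x, noise)]

    # Hypothetical model output: an imperfect prediction of the image.
    x_hat = [xi + 0.1 * rng.gauss(0.0, 1.0) for xi in x]
    # Noise prediction implied by that image prediction.
    eps_hat = [(zi - alpha * xh) / sigma for zi, xh in zip(z, x_hat)]

    noise_loss = sum((eh - ei) ** 2 for eh, ei in zip(eps_hat, noise))
    snr_t = math.exp(-gamma_t)
    image_loss = snr_t * sum((xh - xi) ** 2 for xh, xi in zip(x_hat, x))
    return noise_loss, image_loss


if __name__ == "__main__":
    for gamma in (-3.0, 0.0, 2.5):
        print(gamma, check_equivalence(gamma))
```

Under these assumptions the two values agree up to floating-point error for any gamma_t. This suggests that the snr_t factor is what keeps the image-prediction loss on the same scale as the original noise-prediction loss.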
