-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloss.py
More file actions
21 lines (18 loc) · 834 Bytes
/
loss.py
File metadata and controls
21 lines (18 loc) · 834 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
# Default device name. Not read by loss_fn / loss_fn_cond, which place the
# tensors they create on x.device instead.
DEVICE: str = "cuda"
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """Denoising score-matching loss for an unconditional score model.

    Draws one time ``t`` per batch element, perturbs ``x`` with Gaussian
    noise ``z`` scaled by ``marginal_prob_std(t)``, and penalises
    ``(score * std + z) ** 2``. The penalty is summed over every non-batch
    dimension and averaged over the batch.

    Args:
        model: Callable ``model(perturbed_x, t)``; presumably returns a tensor
            shaped like ``x`` (TODO confirm against the model).
        x: Batch of data whose dimension 0 is the batch dimension. Any rank
            >= 1 is supported; 4-D input behaves exactly as before.
        marginal_prob_std: Callable mapping the ``(B,)`` tensor of sampled
            times to per-sample standard deviations (assumed ``(B,)`` —
            TODO confirm).
        eps: Offset keeping sampled times away from the interval ends.

    Returns:
        Scalar tensor: the batch-mean of the per-sample summed squared error.
    """
    # torch.rand samples [0, 1), so t ~ Uniform[eps, 1 - eps).
    # NOTE(review): loss_fn_cond samples [eps, 1) instead; confirm which
    # interval is intended.
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - 2 * eps) + eps
    std = marginal_prob_std(random_t)
    z = torch.randn_like(x)
    # Shape std as (B, 1, ..., 1) so it broadcasts per sample for any rank of
    # x. For 4-D x this is identical to std[:, None, None, None].
    std_b = std.reshape(-1, *([1] * (x.dim() - 1)))
    perturbed_x = x + z * std_b
    score = model(perturbed_x, random_t)
    sq_err = (score * std_b + z) ** 2
    # Sum over all non-batch dims. Skip the sum for 1-D x: an empty `dim`
    # tuple would reduce over the whole tensor instead of nothing.
    reduce_dims = tuple(range(1, x.dim()))
    per_sample = sq_err.sum(dim=reduce_dims) if reduce_dims else sq_err
    return torch.mean(per_sample)
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """Denoising score-matching loss for a conditional score model.

    Same objective as the unconditional loss: draw one time ``t`` per batch
    element, perturb ``x`` with noise ``z`` scaled by
    ``marginal_prob_std(t)``, and penalise ``(score * std + z) ** 2``. The
    penalty is summed over non-batch dims and averaged over the batch. The
    conditioning ``y`` is forwarded to the model as the keyword ``y``.

    Args:
        model: Callable ``model(perturbed_x, t, y=y)``; presumably returns a
            tensor shaped like ``x`` (TODO confirm against the model).
        x: Batch of data whose dimension 0 is the batch dimension. Any rank
            >= 1 is supported; 4-D input behaves exactly as before.
        y: Conditioning passed through unchanged to ``model``.
        marginal_prob_std: Callable mapping the ``(B,)`` tensor of sampled
            times to per-sample standard deviations (assumed ``(B,)`` —
            TODO confirm).
        eps: Lower bound offset for the sampled times.

    Returns:
        Scalar tensor: the batch-mean of the per-sample summed squared error.
    """
    # torch.rand samples [0, 1), so t ~ Uniform[eps, 1).
    # NOTE(review): loss_fn samples [eps, 1 - eps) instead; confirm which
    # interval is intended.
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    # Shape std as (B, 1, ..., 1) so it broadcasts per sample for any rank of
    # x. For 4-D x this is identical to std[:, None, None, None].
    std_b = std.reshape(-1, *([1] * (x.dim() - 1)))
    perturbed_x = x + z * std_b
    score = model(perturbed_x, random_t, y=y)
    sq_err = (score * std_b + z) ** 2
    # Sum over all non-batch dims. Skip the sum for 1-D x: an empty `dim`
    # tuple would reduce over the whole tensor instead of nothing.
    reduce_dims = tuple(range(1, x.dim()))
    per_sample = sq_err.sum(dim=reduce_dims) if reduce_dims else sq_err
    return torch.mean(per_sample)