|
6 | 6 | EPS = 1e-8 |
7 | 7 |
|
8 | 8 | def jensen_shannon_regularization(encodings): |
9 | | - # 1) Compute the average distribution p |
10 | | - p = encodings.mean(dim=0) |
11 | | - |
12 | | - # 2) Define uniform distribution u |
13 | | - K = p.size(0) |
14 | | - u = torch.ones_like(p) / K |
15 | | - |
16 | | - # 3) Compute the midpoint m = (p + u) / 2 |
17 | | - m = 0.5 * (p + u) |
18 | | - |
19 | | - # 4) Use the definition of JSD(p || u): |
20 | | - # JSD(p || u) = 0.5 * KL(p || m) + 0.5 * KL(u || m) |
21 | | - # KL(x || y) = sum( x_i * log(x_i / y_i) ) |
22 | | - eps = 1e-10 |
23 | | - |
24 | | - kl_p_m = torch.sum(p * torch.log((p + eps) / (m + eps))) |
25 | | - kl_u_m = torch.sum(u * torch.log((u + eps) / (m + eps))) |
26 | | - |
27 | | - jsd = 0.5 * kl_p_m + 0.5 * kl_u_m |
28 | | - return jsd |
| 9 | + # 1) Compute the average distribution p |
| 10 | + p = encodings.mean(dim=0) |
| 11 | + |
| 12 | + # 2) Define uniform distribution u |
| 13 | + K = p.size(0) |
| 14 | + u = torch.ones_like(p) / K |
| 15 | + |
| 16 | + # 3) Compute the midpoint m = (p + u) / 2 |
| 17 | + m = 0.5 * (p + u) |
| 18 | + |
| 19 | + # 4) Use the definition of JSD(p || u): |
| 20 | + # JSD(p || u) = 0.5 * KL(p || m) + 0.5 * KL(u || m) |
| 21 | + # KL(x || y) = sum( x_i * log(x_i / y_i) ) |
| 22 | + eps = 1e-10 |
| 23 | + |
| 24 | + kl_p_m = torch.sum(p * torch.log((p + eps) / (m + eps))) |
| 25 | + kl_u_m = torch.sum(u * torch.log((u + eps) / (m + eps))) |
| 26 | + |
| 27 | + jsd = 0.5 * kl_p_m + 0.5 * kl_u_m |
| 28 | + return jsd |
29 | 29 |
|
30 | 30 | #jaccard distance multiset loss for protein pairs |
31 | 31 |
|
@@ -157,13 +157,16 @@ def aa_reconstruction_loss(x, recon_x): |
157 | 157 |
|
158 | 158 | """ |
159 | 159 | def angles_reconstruction_loss(true, pred): |
160 | | - delta = pred - true |
161 | | - return (1.0 - torch.cos(delta)).mean() |
| 160 | + delta = pred - true |
| 161 | + return (1.0 - torch.cos(delta)).mean() |
162 | 162 | """ |
163 | 163 |
|
def angles_reconstruction_loss(true, pred, beta=0.5 , plddt_mask = None):
    """Smooth-L1 loss on the wrapped angular difference between pred and true.

    The difference is wrapped into (-pi, pi] via atan2(sin, cos) so angles
    near the +/-pi boundary are not penalised as if they were far apart.

    Args:
        true: tensor of target angles in radians.
        pred: tensor of predicted angles, same shape as ``true``.
        beta: smooth-L1 transition point (quadratic below, linear above).
        plddt_mask: optional per-element weight tensor broadcastable to the
            angle shape (e.g. pLDDT confidence weights); when given, the loss
            is the mask-weighted mean of the per-element losses.

    Returns:
        Scalar tensor: mean (or mask-weighted mean) smooth-L1 loss.
    """
    # Wrap the angular error into (-pi, pi] to respect periodicity.
    delta = torch.atan2(torch.sin(pred - true), torch.cos(pred - true))
    # Keep per-element losses so the optional mask can weight them BEFORE
    # reduction. (The previous code multiplied the already mean-reduced
    # scalar by the mask, which only rescaled the loss by mean(mask)
    # instead of down-weighting individual low-confidence positions.)
    per_elem = F.smooth_l1_loss(delta, torch.zeros_like(delta),
                                beta=beta, reduction="none")
    if plddt_mask is None:
        return per_elem.mean()
    # Weighted mean, guarded against an all-zero mask.
    return (per_elem * plddt_mask).sum() / plddt_mask.sum().clamp(min=1e-8)
167 | 170 |
|
168 | 171 |
|
169 | 172 | def gaussian_loss(mu , logvar , beta= 1.5): |
|
0 commit comments