Skip to content

Commit 0f41a5f

Browse files
author
dmoi
committed
angle training working
1 parent 98f86e7 commit 0f41a5f

File tree

3 files changed

+3004
-106
lines changed

3 files changed

+3004
-106
lines changed

foldtree2/notebooks/experiments/test_monodecoders.ipynb

Lines changed: 2976 additions & 81 deletions
Large diffs are not rendered by default.
-237 Bytes
Binary file not shown.

foldtree2/src/losses/losses.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@
66
EPS = 1e-8
77

88
def jensen_shannon_regularization(encodings):
    """Jensen-Shannon divergence between the batch-average code distribution
    and the uniform distribution.

    Encourages the encoder to use all K codes evenly: the loss is 0 when the
    average over the batch is exactly uniform and grows as usage collapses
    onto a few codes.

    Parameters
    ----------
    encodings : torch.Tensor
        (batch, K) tensor of per-sample distributions over K codes
        (assumed non-negative, rows summing to ~1 — TODO confirm with caller).

    Returns
    -------
    torch.Tensor
        Scalar JSD(p || u) where p = encodings.mean(dim=0) and u = 1/K.
    """
    smoothing = 1e-10  # avoids log(0); kept distinct from module-level EPS on purpose

    # Batch-average distribution p and its uniform reference u.
    avg = encodings.mean(dim=0)
    uniform = torch.ones_like(avg) / avg.size(0)

    # Midpoint m of the two distributions.
    mid = 0.5 * (avg + uniform)

    # KL(x || y) = sum_i x_i * log(x_i / y_i), smoothed for numerical safety.
    def _kl(x, y):
        return torch.sum(x * torch.log((x + smoothing) / (y + smoothing)))

    # JSD(p || u) = 0.5 * KL(p || m) + 0.5 * KL(u || m)
    return 0.5 * _kl(avg, mid) + 0.5 * _kl(uniform, mid)
2929

3030
#jaccard distance multiset loss for protein pairs
3131

@@ -157,13 +157,16 @@ def aa_reconstruction_loss(x, recon_x):
157157

158158
"""
159159
def angles_reconstruction_loss(true, pred):
160-
delta = pred - true
161-
return (1.0 - torch.cos(delta)).mean()
160+
delta = pred - true
161+
return (1.0 - torch.cos(delta)).mean()
162162
"""
163163

def angles_reconstruction_loss(true, pred, beta=0.5, plddt_mask=None):
    """Smooth-L1 loss on the wrapped angular difference between pred and true.

    The raw difference is mapped into (-pi, pi] via atan2(sin, cos) so that
    angles on either side of the +/-pi boundary are compared correctly
    (a plain ``pred - true`` would score pi vs -pi as a 2*pi error).

    Parameters
    ----------
    true : torch.Tensor
        Target angles in radians.
    pred : torch.Tensor
        Predicted angles in radians, same shape as ``true``.
    beta : float, optional
        Smooth-L1 quadratic/linear transition point (default 0.5).
    plddt_mask : torch.Tensor or None, optional
        Per-element confidence weights, broadcastable to ``true``'s shape.
        When given, the loss is the mask-weighted mean of the per-element
        losses, so low-confidence residues are down-weighted.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    # Wrap the difference into (-pi, pi].
    delta = torch.atan2(torch.sin(pred - true), torch.cos(pred - true))
    # reduction='none' keeps per-element losses so the mask can weight them.
    # (The previous code masked the already-reduced scalar, which only
    # rescaled the total by mean(mask) instead of weighting elements.)
    per_elem = F.smooth_l1_loss(delta, torch.zeros_like(delta),
                                beta=beta, reduction='none')
    if plddt_mask is not None:
        # Weighted mean, normalised by the mask sum (clamped to avoid
        # division by zero when the mask is all zeros).
        return (per_elem * plddt_mask).sum() / plddt_mask.sum().clamp_min(1e-8)
    return per_elem.mean()
167170

168171

169172
def gaussian_loss(mu , logvar , beta= 1.5):

0 commit comments

Comments (0)