
Commit 4c62149

Author: Joseph Watson
Commit message: Initial commit to RFdiffusion
Parents: none (initial commit)

File tree: 108 files changed, +38150 −0 lines


Attention_module.py: 404 additions, 0 deletions (large diff not rendered by default)

AuxiliaryPredictor.py: 92 additions, 0 deletions
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn

class DistanceNetwork(nn.Module):
    def __init__(self, n_feat, p_drop=0.1):
        super(DistanceNetwork, self).__init__()
        # symmetric head: dist + omega logits; asymmetric head: theta + phi logits
        self.proj_symm = nn.Linear(n_feat, 37*2)
        self.proj_asymm = nn.Linear(n_feat, 37+19)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer for final logit prediction
        nn.init.zeros_(self.proj_symm.weight)
        nn.init.zeros_(self.proj_asymm.weight)
        nn.init.zeros_(self.proj_symm.bias)
        nn.init.zeros_(self.proj_asymm.bias)

    def forward(self, x):
        # input: pair info (B, L, L, C)

        # predict theta, phi (non-symmetric)
        logits_asymm = self.proj_asymm(x)
        logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)  # (B, 37, L, L)
        logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)    # (B, 19, L, L)

        # predict dist, omega (symmetrized over the two residue dimensions)
        logits_symm = self.proj_symm(x)
        logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
        logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2)    # (B, 37, L, L)
        logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2)   # (B, 37, L, L)

        return logits_dist, logits_omega, logits_theta, logits_phi

class MaskedTokenNetwork(nn.Module):
    def __init__(self, n_feat, p_drop=0.1):
        super(MaskedTokenNetwork, self).__init__()
        self.proj = nn.Linear(n_feat, 21)  # logits over 21 sequence tokens

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        # input: MSA features (B, N, L, C)
        B, N, L = x.shape[:3]
        logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)  # (B, 21, N*L)

        return logits

class LDDTNetwork(nn.Module):
    def __init__(self, n_feat, n_bin_lddt=50):
        super(LDDTNetwork, self).__init__()
        self.proj = nn.Linear(n_feat, n_bin_lddt)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        logits = self.proj(x)  # (B, L, 50)

        return logits.permute(0,2,1)  # (B, 50, L)

class ExpResolvedNetwork(nn.Module):
    def __init__(self, d_msa, d_state, p_drop=0.1):
        super(ExpResolvedNetwork, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj = nn.Linear(d_msa+d_state, 1)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, seq, state):
        B, L = seq.shape[:2]

        seq = self.norm_msa(seq)
        state = self.norm_state(state)
        feat = torch.cat((seq, state), dim=-1)
        logits = self.proj(feat)
        return logits.reshape(B, L)  # one "experimentally resolved" logit per residue
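As a sanity check on the shapes these heads produce, here is a minimal smoke test (not part of the commit). It assumes the file is importable as AuxiliaryPredictor, and the sizes B, N, L, C below are arbitrary placeholders. Note that every weight and bias is zeroed by reset_parameter(), so all logits are exactly zero at initialization; only the shapes are informative here.

import torch
from AuxiliaryPredictor import (DistanceNetwork, MaskedTokenNetwork,
                                LDDTNetwork, ExpResolvedNetwork)

# placeholder sizes: batch, MSA depth, sequence length, feature width
B, N, L, C = 1, 4, 32, 64

# pairwise head: dist/omega are symmetric in the two residue axes, theta/phi are not
dist, omega, theta, phi = DistanceNetwork(n_feat=C)(torch.randn(B, L, L, C))
print(dist.shape, theta.shape, phi.shape)
# torch.Size([1, 37, 32, 32]) torch.Size([1, 37, 32, 32]) torch.Size([1, 19, 32, 32])

# masked-token head: one 21-way logit vector per MSA position, flattened to N*L columns
print(MaskedTokenNetwork(n_feat=C)(torch.randn(B, N, L, C)).shape)
# torch.Size([1, 21, 128])

# lDDT head: a 50-bin distribution per residue
print(LDDTNetwork(n_feat=C)(torch.randn(B, L, C)).shape)
# torch.Size([1, 50, 32])

# experimentally-resolved head: a single logit per residue
exp_net = ExpResolvedNetwork(d_msa=C, d_state=C)
print(exp_net(torch.randn(B, L, C), torch.randn(B, L, C)).shape)
# torch.Size([1, 32])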
