
Commit 2ea0689

Author: dmoi
Message: messing with encoder mk3
1 parent: ab9e318

File tree: 3 files changed, +840 -107 lines

foldtree2/notebooks/experiments/test_monodecoders.ipynb
789 additions, 91 deletions (large diff not rendered by default)

(binary file, 1.48 KB, not shown)

foldtree2/src/mono_decoders.py
51 additions, 16 deletions
@@ -923,6 +923,7 @@ def __init__(
             dropout=0.001,
             normalize=True,
             residual=True,
+            output_ss=False,
             **kwargs
             ):
         super(Transformer_AA_Decoder, self).__init__()
@@ -951,18 +952,7 @@ def __init__(
         encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nheads, dropout=dropout)
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers)
 
-        self.dnn_decoder = torch.nn.Sequential(
-            # layernorm
-            torch.nn.LayerNorm(d_model),
-            torch.nn.Linear(d_model, AAdecoder_hidden[0]),
-            torch.nn.GELU(),
-            torch.nn.Linear(AAdecoder_hidden[0], AAdecoder_hidden[1]),
-            torch.nn.GELU(),
-            torch.nn.Linear(AAdecoder_hidden[1], AAdecoder_hidden[2]),
-            torch.nn.GELU(),
-            torch.nn.Linear(AAdecoder_hidden[2], 20),
-            torch.nn.LogSoftmax(dim=1)
-        )
+
 
         self.cnn_decoder = None
         if use_cnn_decoder := kwargs.get('use_cnn_decoder', False):
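The context lines show that the CNN/DNN choice is gated on a kwargs flag bound with the walrus operator, and the next hunk moves the dense decoder into the else branch of that gate. A minimal sketch of the gating pattern, with illustrative names (make_heads is not part of the repo):

import torch

# Hedged sketch: `use_cnn_decoder := kwargs.get('use_cnn_decoder', False)`
# binds and tests the flag in one expression, so exactly one decoder is built.
def make_heads(d_model=64, **kwargs):
    cnn_decoder, dnn_decoder = None, None
    if use_cnn_decoder := kwargs.get('use_cnn_decoder', False):
        cnn_decoder = torch.nn.Conv1d(d_model, 20, kernel_size=3, padding=1)
    else:
        dnn_decoder = torch.nn.Sequential(
            torch.nn.LayerNorm(d_model),
            torch.nn.Linear(d_model, 20),
            torch.nn.LogSoftmax(dim=-1),
        )
    return cnn_decoder, dnn_decoder

print(make_heads())                        # (None, Sequential(...))
print(make_heads(use_cnn_decoder=True))    # (Conv1d(...), None)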
@@ -981,6 +971,32 @@ def __init__(
                 torch.nn.Conv1d(AAdecoder_hidden[2], 20, kernel_size=1),
                 # Transpose back to (seq_len, batch, features) and apply softmax
             )
+
+        else:
+            self.dnn_decoder = torch.nn.Sequential(
+                # layernorm
+                torch.nn.LayerNorm(d_model),
+                torch.nn.Linear(d_model, AAdecoder_hidden[0]),
+                torch.nn.GELU(),
+                torch.nn.Linear(AAdecoder_hidden[0], AAdecoder_hidden[1]),
+                torch.nn.GELU(),
+                torch.nn.Linear(AAdecoder_hidden[1], AAdecoder_hidden[2]),
+                torch.nn.GELU(),
+                torch.nn.Linear(AAdecoder_hidden[2], 20),
+                torch.nn.LogSoftmax(dim=1)
+            )
+
+        if output_ss:
+            # Secondary structure head
+            self.ss_head = torch.nn.Sequential(
+                torch.nn.Linear(d_model, AAdecoder_hidden[0]),
+                torch.nn.GELU(),
+                torch.nn.Linear(AAdecoder_hidden[0], AAdecoder_hidden[1]),
+                torch.nn.GELU(),
+                torch.nn.Linear(AAdecoder_hidden[1], 3),
+                torch.nn.LogSoftmax(dim=1)
+            )
+
 
     def forward(self, data, **kwargs):
         x = data.x_dict['res']
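A minimal sketch of the two-head pattern this commit introduces: a shared trunk feeding a 20-way amino-acid head and an optional 3-state secondary structure head. TwoHeadDecoder and its dimensions are illustrative, not the repo's class. Note the explicit `self.ss_head = None` default: in the diff, ss_head is only assigned inside `if output_ss:`, while forward() tests `self.ss_head is not None`, so that guard is only safe if the attribute also exists when the head is disabled.

import torch

class TwoHeadDecoder(torch.nn.Module):
    def __init__(self, d_model=64, hidden=32, output_ss=False):
        super().__init__()
        self.aa_head = torch.nn.Sequential(
            torch.nn.LayerNorm(d_model),
            torch.nn.Linear(d_model, hidden),
            torch.nn.GELU(),
            torch.nn.Linear(hidden, 20),
            torch.nn.LogSoftmax(dim=-1),
        )
        self.ss_head = None  # default, so "is not None" checks never raise
        if output_ss:
            self.ss_head = torch.nn.Sequential(
                torch.nn.Linear(d_model, hidden),
                torch.nn.GELU(),
                torch.nn.Linear(hidden, 3),  # 3-state secondary structure
                torch.nn.LogSoftmax(dim=-1),
            )

    def forward(self, x):  # x: (seq_len, d_model)
        ss = self.ss_head(x) if self.ss_head is not None else None
        return {'aa': self.aa_head(x), 'ss_pred': ss}

x = torch.randn(7, 64)  # 7 residues
out = TwoHeadDecoder(output_ss=True)(x)
print(out['aa'].shape, out['ss_pred'].shape)  # torch.Size([7, 20]) torch.Size([7, 3])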
@@ -1008,35 +1024,47 @@ def forward(self, data, **kwargs):
             x = x.unsqueeze(1)  # (seq_len, 1, d_model)
 
         x = self.transformer_encoder(x)  # (N, batch, d_model)
-
+        ss = None
         if batch is not None:
             # Remove padding and concatenate results for all graphs in the batch
             aa_list = []
+            ss_list = []
             for i, xi in enumerate(x.split(1, dim=1)):  # xi: (seq_len, 1, d_model)
                 # Remove batch dimension and padding (assume original lengths from batch)
                 seq_len = (batch == i).sum().item()
                 if self.cnn_decoder is not None:
                     # Apply CNN decoder
                     xi = self.prenorm(xi.squeeze(1))  # (seq_len, d_model)
+                    if self.ss_head is not None:
+                        ss_list.append(self.ss_head(xi[:seq_len, :]))
                     xi_cnn = xi.permute(1, 0).unsqueeze(0)  # (1, d_model, seq_len)
                     xi_cnn = self.cnn_decoder(xi_cnn)  # (1, 20, seq_len)
                     xi_cnn = xi_cnn.permute(2, 0, 1).squeeze(1)  # (seq_len, 20)
                     aa_list.append(F.log_softmax(xi_cnn[:seq_len, :], dim=-1))
                 else:
                     aa_list.append(self.dnn_decoder(xi[:seq_len, 0]))
+                    if self.ss_head is not None:
+                        ss_list.append(self.ss_head(xi[:seq_len, 0]))
             aa = torch.cat(aa_list, dim=0)
-            return {'aa': aa }
+            if self.ss_head is not None:
+                ss = torch.cat(ss_list, dim=0)
+            return {'aa': aa, 'ss_pred': ss }
+
         else:
             if self.cnn_decoder is not None:
                 # Apply CNN decoder
                 x = self.prenorm(x)
+                if self.ss_head is not None:
+                    ss = self.ss_head(x)
                 x_cnn = x.permute(1, 2, 0)  # (batch, d_model, seq_len)
                 x_cnn = self.cnn_decoder(x_cnn)  # (batch, xdim, seq_len)
                 x_cnn = x_cnn.permute(2, 0, 1)  # (seq_len, batch, xdim)
                 aa = F.log_softmax(x_cnn, dim=-1)
             else:
                 aa = self.dnn_decoder(x)
-            return {'aa': aa}
+            if self.ss_head is not None:
+                ss = self.ss_head(x)
+            return {'aa': aa, 'ss_pred': ss}
 
     def x_to_amino_acid_sequence(self, x_r):
         indices = torch.argmax(x_r, dim=1)
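The batched branch above recovers each graph's true length from a PyG-style batch index vector via `(batch == i).sum()`, slices off the padding, and concatenates per-graph outputs. A hedged, self-contained sketch of that de-padding step (unpad_concat is an illustrative name, not a repo function):

import torch

def unpad_concat(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # x: (max_len, num_graphs, d) padded; batch maps each real residue to its graph
    pieces = []
    for i, xi in enumerate(x.split(1, dim=1)):  # xi: (max_len, 1, d)
        seq_len = (batch == i).sum().item()      # residues belonging to graph i
        pieces.append(xi[:seq_len, 0])           # drop padding and the batch dim
    return torch.cat(pieces, dim=0)              # (total_residues, d)

batch = torch.tensor([0, 0, 0, 1, 1])  # graph 0 has 3 residues, graph 1 has 2
x = torch.randn(3, 2, 8)               # padded to max_len = 3
print(unpad_concat(x, batch).shape)    # torch.Size([5, 8])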
@@ -1374,7 +1402,14 @@ def __init__(self, configs):
     def forward(self, data, contact_pred_index=None, **kwargs):
         results = {}
         for task, decoder in self.decoders.items():
-            results.update(decoder(data, contact_pred_index=contact_pred_index, **kwargs))
+            # If a decoder returns None for a key that already exists in results
+            # and the existing value is not None, keep the existing value;
+            # otherwise, update results with the new value.
+            for key, value in decoder(data, contact_pred_index=contact_pred_index, **kwargs).items():
+                if key in results and results[key] is not None and value is None:
+                    continue
+                else:
+                    results[key] = value
         return results
 
 
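The merge rule above replaces a plain dict.update() because later decoders may return None for keys (such as 'ss_pred') that an earlier decoder already filled in, and a None must not clobber a real value. A standalone sketch of the same rule (merge_results is an illustrative name):

def merge_results(results: dict, new: dict) -> dict:
    for key, value in new.items():
        if key in results and results[key] is not None and value is None:
            continue  # keep the existing non-None value
        results[key] = value
    return results

r = merge_results({'aa': 'logits_a', 'ss_pred': 'logits_s'},
                  {'aa': 'logits_b', 'ss_pred': None})
print(r)  # {'aa': 'logits_b', 'ss_pred': 'logits_s'}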