@@ -923,6 +923,7 @@ def __init__(
923923 dropout = 0.001 ,
924924 normalize = True ,
925925 residual = True ,
926+ output_ss = False ,
926927 ** kwargs
927928 ):
928929 super (Transformer_AA_Decoder , self ).__init__ ()
@@ -951,18 +952,7 @@ def __init__(
951952 encoder_layer = nn .TransformerEncoderLayer (d_model = d_model , nhead = nheads , dropout = dropout )
952953 self .transformer_encoder = nn .TransformerEncoder (encoder_layer , num_layers = layers )
953954
954- self .dnn_decoder = torch .nn .Sequential (
955- #layernorm
956- torch .nn .LayerNorm (d_model ),
957- torch .nn .Linear (d_model , AAdecoder_hidden [0 ]),
958- torch .nn .GELU (),
959- torch .nn .Linear (AAdecoder_hidden [0 ], AAdecoder_hidden [1 ]),
960- torch .nn .GELU (),
961- torch .nn .Linear (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ]),
962- torch .nn .GELU (),
963- torch .nn .Linear (AAdecoder_hidden [2 ], 20 ),
964- torch .nn .LogSoftmax (dim = 1 )
965- )
955+
966956
967957 self .cnn_decoder = None
968958 if use_cnn_decoder := kwargs .get ('use_cnn_decoder' , False ):
@@ -981,6 +971,32 @@ def __init__(
981971 torch .nn .Conv1d (AAdecoder_hidden [2 ], 20 , kernel_size = 1 ),
982972 # Transpose back to (seq_len, batch, features) and apply softmax
983973 )
974+
975+ else :
976+ self .dnn_decoder = torch .nn .Sequential (
977+ #layernorm
978+ torch .nn .LayerNorm (d_model ),
979+ torch .nn .Linear (d_model , AAdecoder_hidden [0 ]),
980+ torch .nn .GELU (),
981+ torch .nn .Linear (AAdecoder_hidden [0 ], AAdecoder_hidden [1 ]),
982+ torch .nn .GELU (),
983+ torch .nn .Linear (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ]),
984+ torch .nn .GELU (),
985+ torch .nn .Linear (AAdecoder_hidden [2 ], 20 ),
986+ torch .nn .LogSoftmax (dim = 1 )
987+ )
988+
989+ if output_ss :
990+ # Secondary structure head
991+ self .ss_head = torch .nn .Sequential (
992+ torch .nn .Linear (d_model , AAdecoder_hidden [0 ]),
993+ torch .nn .GELU (),
994+ torch .nn .Linear (AAdecoder_hidden [0 ], AAdecoder_hidden [1 ]),
995+ torch .nn .GELU (),
996+ torch .nn .Linear (AAdecoder_hidden [1 ], 3 ),
997+ torch .nn .LogSoftmax (dim = 1 )
998+ )
999+
9841000
9851001 def forward (self , data , ** kwargs ):
9861002 x = data .x_dict ['res' ]
@@ -1008,35 +1024,47 @@ def forward(self, data, **kwargs):
10081024 x = x .unsqueeze (1 ) # (seq_len, 1, d_model)
10091025
10101026 x = self .transformer_encoder (x ) # (N, batch, d_model)
1011-
1027+ ss = None
10121028 if batch is not None :
10131029 # Remove padding and concatenate results for all graphs in the batch
10141030 aa_list = []
1031+ ss_list = []
10151032 for i , xi in enumerate (x .split (1 , dim = 1 )): # xi: (seq_len, 1, d_model)
10161033 # Remove batch dimension and padding (assume original lengths from batch)
10171034 seq_len = (batch == i ).sum ().item ()
10181035 if self .cnn_decoder is not None :
10191036 # Apply CNN decoder
10201037 xi = self .prenorm (xi .squeeze (1 )) # (seq_len, d_model)
1038+ if self .ss_head is not None :
1039+ ss_list .append (self .ss_head (xi [:seq_len , :]))
10211040 xi_cnn = xi .permute (1 , 0 ).unsqueeze (0 ) # (1, d_model, seq_len)
10221041 xi_cnn = self .cnn_decoder (xi_cnn ) # (1, 20, seq_len)
10231042 xi_cnn = xi_cnn .permute (2 , 0 , 1 ).squeeze (1 ) # (seq_len, 20)
10241043 aa_list .append (F .log_softmax (xi_cnn [:seq_len , :], dim = - 1 ))
10251044 else :
10261045 aa_list .append (self .dnn_decoder (xi [:seq_len , 0 ]))
1046+ if self .ss_head is not None :
1047+ ss_list .append (self .ss_head (xi [:seq_len , 0 ]))
10271048 aa = torch .cat (aa_list , dim = 0 )
1028- return {'aa' : aa }
1049+ if self .ss_head is not None :
1050+ ss = torch .cat (ss_list , dim = 0 )
1051+ return {'aa' : aa , 'ss_pred' : ss }
1052+
10291053 else :
10301054 if self .cnn_decoder is not None :
10311055 # Apply CNN decoder
10321056 x = self .prenorm (x )
1057+ if self .ss_head is not None :
1058+ ss = self .ss_head (x )
10331059 x_cnn = x .permute (1 , 2 , 0 ) # (batch, d_model, seq_len)
10341060 x_cnn = self .cnn_decoder (x_cnn ) # (batch, xdim, seq_len)
10351061 x_cnn = x_cnn .permute (2 , 0 , 1 ) # (seq_len, batch, xdim)
10361062 aa = F .log_softmax (x_cnn , dim = - 1 )
10371063 else :
10381064 aa = self .dnn_decoder (x )
1039- return {'aa' : aa }
1065+ if self .ss_head is not None :
1066+ ss = self .ss_head (x )
1067+ return {'aa' : aa , 'ss_pred' : ss }
10401068
10411069 def x_to_amino_acid_sequence (self , x_r ):
10421070 indices = torch .argmax (x_r , dim = 1 )
@@ -1374,7 +1402,14 @@ def __init__(self, configs):
13741402 def forward (self , data , contact_pred_index = None , ** kwargs ):
13751403 results = {}
13761404 for task , decoder in self .decoders .items ():
1377- results .update (decoder (data , contact_pred_index = contact_pred_index , ** kwargs ))
1405+ #if a decoder returns a value for a key that already exists in results that is none
1406+ #and existing value is not none, keep the existing value
1407+ #otherwise, update the results with the new value
1408+ for key , value in decoder (data , contact_pred_index = contact_pred_index , ** kwargs ).items ():
1409+ if key in results and results [key ] is not None and value is None :
1410+ continue
1411+ else :
1412+ results [key ] = value
13781413 return results
13791414
13801415
0 commit comments