@@ -3099,31 +3099,22 @@ def __init__(
30993099 @typecheck
31003100 def forward (
31013101 self ,
3102- pred_coords : Float ['b m_or_n 3' ],
3103- true_coords : Float ['b m_or_n 3' ],
3102+ pred_coords : Float ['b m 3' ],
3103+ true_coords : Float ['b m 3' ],
31043104 pred_frames : Float ['b n 3 3' ],
31053105 true_frames : Float ['b n 3 3' ],
3106- mask : Bool ['b m_or_n' ] | None = None ,
3107- molecule_atom_lens : Int ['b n ' ] | None = None
3108- ) -> Float ['b m_or_n m_or_n ' ]:
3106+ molecule_atom_lens : Int ['b n' ] ,
3107+ mask : Bool ['b m ' ] | None = None ,
3108+ ) -> Float ['b m m ' ]:
31093109 """
31103110 pred_coords: predicted coordinates
31113111 true_coords: true coordinates
31123112 pred_frames: predicted frames
31133113 true_frames: true frames
31143114 """
31153115
3116- # detect whether using atom or residue resolution
3117-
3118- is_atom_resolution = pred_coords .shape [1 ] != pred_frames .shape [1 ]
3119- assert not is_atom_resolution or exists (molecule_atom_lens ), '`molecule_atom_lens` must be passed in for atom resolution alignment error'
3120-
3121- if is_atom_resolution :
3122- pred_frames = batch_repeat_interleave (pred_frames , molecule_atom_lens )
3123- true_frames = batch_repeat_interleave (true_frames , molecule_atom_lens )
3124-
3125- if not exists (mask ) and exists (molecule_atom_lens ):
3126- mask = batch_repeat_interleave (molecule_atom_lens > 0 , molecule_atom_lens )
3116+ pred_frames = batch_repeat_interleave (pred_frames , molecule_atom_lens )
3117+ true_frames = batch_repeat_interleave (true_frames , molecule_atom_lens )
31273118
31283119 # to pairs
31293120
@@ -3445,7 +3436,7 @@ def forward(
34453436
34463437class ConfidenceHeadLogits (NamedTuple ):
34473438 pae : Float ['b pae m m' ] | None
3448- pde : Float ['b pde m m ' ]
3439+ pde : Float ['b pde n n ' ]
34493440 plddt : Float ['b plddt m' ]
34503441 resolved : Float ['b 2 m' ]
34513442
@@ -3514,6 +3505,8 @@ def __init__(
35143505
35153506 self .atom_feats_to_single = LinearNoBias (dim_atom , dim_single )
35163507
3508+ self .atom_feats_to_pairwise = LinearNoBiasThenOuterSum (dim_atom , dim_pairwise )
3509+
35173510 # tensor typing
35183511
35193512 self .da = dim_atom
@@ -3570,12 +3563,16 @@ def forward(
35703563 single_repr = single_repr , pairwise_repr = pairwise_repr , mask = mask
35713564 )
35723565
3573- # handle atom level resolution
3566+ # handle atom level resolution for single and pairwise
35743567
35753568 atom_single_repr = batch_repeat_interleave (single_repr , molecule_atom_lens )
35763569
35773570 atom_single_repr = atom_single_repr + self .atom_feats_to_single (atom_feats )
35783571
3572+ atom_pairwise_repr = batch_repeat_interleave_pairwise (pairwise_repr , molecule_atom_lens )
3573+
3574+ atom_pairwise_repr = atom_pairwise_repr + self .atom_feats_to_pairwise (atom_feats )
3575+
35793576 # to logits
35803577
35813578 pde_logits = self .to_pde_logits (symmetrize (pairwise_repr ))
@@ -3588,7 +3585,7 @@ def forward(
35883585 pae_logits = None
35893586
35903587 if return_pae_logits :
3591- pae_logits = self .to_pae_logits (pairwise_repr )
3588+ pae_logits = self .to_pae_logits (atom_pairwise_repr )
35923589
35933590 # return all logits
35943591
@@ -3642,11 +3639,12 @@ def _calculate_bin_centers(
36423639 @typecheck
36433640 def forward (
36443641 self ,
3645- confidence_head_logits : ConfidenceHeadLogits ,
3642+ pae_logits : Float ["b pae m m" ],
3643+ plddt_logits : Float ["b plddt m" ],
36463644 asym_id : Int ["b n" ],
36473645 has_frame : Bool ["b n" ],
3646+ molecule_atom_lens : Int ["b n" ],
36483647 ptm_residue_weight : Float ["b n" ] | None = None ,
3649- molecule_atom_lens : Int ["b n" ] | None = None ,
36503648 multimer_mode : bool = True ,
36513649 ) -> ConfidenceScore :
36523650 """Main function to compute confidence score.
@@ -3658,21 +3656,19 @@ def forward(
36583656 :param multimer_mode: bool
36593657 :return: Confidence score
36603658 """
3661- plddt = self .compute_plddt (confidence_head_logits . plddt )
3659+ plddt = self .compute_plddt (plddt_logits )
36623660
36633661 # Section 5.9.1 equation 17
36643662 ptm = self .compute_ptm (
3665- confidence_head_logits .pae , asym_id , has_frame , ptm_residue_weight , interface = False ,
3666- molecule_atom_lens = molecule_atom_lens ,
3663+ pae_logits , asym_id , has_frame , molecule_atom_lens , ptm_residue_weight , interface = False
36673664 )
36683665
36693666 iptm = None
36703667
36713668 if multimer_mode :
36723669 # Section 5.9.2 equation 18
36733670 iptm = self .compute_ptm (
3674- confidence_head_logits .pae , asym_id , has_frame , ptm_residue_weight , interface = True ,
3675- molecule_atom_lens = molecule_atom_lens ,
3671+ pae_logits , asym_id , has_frame , molecule_atom_lens , ptm_residue_weight , interface = True
36763672 )
36773673
36783674 confidence_score = ConfidenceScore (plddt = plddt , ptm = ptm , iptm = iptm )
@@ -3700,11 +3696,11 @@ def compute_plddt(
37003696 @typecheck
37013697 def compute_ptm (
37023698 self ,
3703- logits : Float ["b pae m_or_n m_or_n " ],
3699+ logits : Float ["b pae m m " ],
37043700 asym_id : Int ["b n" ],
37053701 has_frame : Bool ["b n" ],
3702+ molecule_atom_lens : Int ["b n" ],
37063703 residue_weights : Float ["b n" ] | None = None ,
3707- molecule_atom_lens : Int ["b n" ] | None = None ,
37083704 interface : bool = False ,
37093705 compute_chain_wise_iptm : bool = False ,
37103706 ) -> Float [" b" ] | Tuple [Float ["b chains chains" ], Bool ["b chains chains" ], Int ["b chains" ]]:
@@ -3720,14 +3716,11 @@ def compute_ptm(
37203716 :return: pTM
37213717 """
37223718
3723- is_atom_resolution = logits .shape [- 1 ] != asym_id .shape [- 1 ]
3724- assert not is_atom_resolution or exists (molecule_atom_lens ), '`molecule_atom_lens` must be passed in for atom resolution pTM'
37253719
3726- if is_atom_resolution :
3727- asym_id = batch_repeat_interleave (asym_id , molecule_atom_lens )
3728- has_frame = batch_repeat_interleave (has_frame , molecule_atom_lens )
3729- if exists (residue_weights ):
3730- residue_weights = batch_repeat_interleave (residue_weights , molecule_atom_lens )
3720+ asym_id = batch_repeat_interleave (asym_id , molecule_atom_lens )
3721+ has_frame = batch_repeat_interleave (has_frame , molecule_atom_lens )
3722+ if exists (residue_weights ):
3723+ residue_weights = batch_repeat_interleave (residue_weights , molecule_atom_lens )
37313724
37323725 if not exists (residue_weights ):
37333726 residue_weights = torch .ones_like (has_frame )
@@ -3988,10 +3981,10 @@ def compute_disorder(
39883981 disorder = ((atom_rasa > 0.581 ) * mask ).sum (dim = - 1 ) / (self .eps + mask .sum (dim = 1 ))
39893982 return disorder
39903983
3991- @typecheck
39923984 def compute_full_complex_metric (
39933985 self ,
3994- confidence_head_logits : ConfidenceHeadLogits ,
3986+ pae_logits : Float ['b pae m m' ],
3987+ plddt_logits : Float ['b plddt m' ],
39953988 asym_id : Int ["b n" ],
39963989 has_frame : Bool ["b n" ],
39973990 molecule_atom_lens : Int ["b n" ],
@@ -4003,7 +3996,8 @@ def compute_full_complex_metric(
40033996
40043997 """Compute full complex metric.
40053998
4006- :param confidence_head_logits: ConfidenceHeadLogits
3999+ :param pae_logits: pae logits from confidence head
4000+ :param plddt_logits: plddt logits from confidence head
40074001 :param asym_id: [b n] asym_id of each residue
40084002 :param has_frame: [b n] has_frame of each residue
40094003 :param molecule_atom_lens: [b n] molecule atom lens
@@ -4035,7 +4029,7 @@ def compute_full_complex_metric(
40354029 atom_is_molecule_types = is_molecule_types .gather (1 , indices ) * valid_indices [..., None ]
40364030
40374031 confidence_score = self .compute_confidence_score (
4038- confidence_head_logits , asym_id , has_frame , multimer_mode = True
4032+ pae_logits , plddt_logits , asym_id , has_frame , molecule_atom_lens , multimer_mode = True
40394033 )
40404034 has_clash = self .compute_clash (
40414035 atom_pos ,
@@ -4062,9 +4056,11 @@ def compute_full_complex_metric(
40624056 @typecheck
40634057 def compute_single_chain_metric (
40644058 self ,
4065- confidence_head_logits : ConfidenceHeadLogits ,
4059+ pae_logits : Float ['b pae m m' ],
4060+ plddt_logits : Float ['b plddt m' ],
40664061 asym_id : Int ["b n" ],
4067- has_frame : Bool ["b n" ],
4062+ has_frame : Bool ["b n" ],
4063+ molecule_atom_lens : Int ["b n" ]
40684064 ) -> Float [" b" ]:
40694065
40704066 """Compute single chain metric.
@@ -4078,18 +4074,18 @@ def compute_single_chain_metric(
40784074 # Section 5.9.3.2
40794075
40804076 confidence_score = self .compute_confidence_score (
4081- confidence_head_logits , asym_id , has_frame , multimer_mode = False
4077+ pae_logits , plddt_logits , asym_id , has_frame , molecule_atom_lens , multimer_mode = False
40824078 )
40834079
4084- score = confidence_score .ptm
4085- return score
4080+ return confidence_score .ptm
40864081
40874082 @typecheck
40884083 def compute_interface_metric (
40894084 self ,
4090- confidence_head_logits : ConfidenceHeadLogits ,
4085+ pae_logits : Float [ 'b pae m m' ] ,
40914086 asym_id : Int ["b n" ],
4092- has_frame : Bool ["b n" ],
4087+ has_frame : Bool ["b n" ],
4088+ molecule_atom_lens : Int ['b n' ],
40934089 interface_chains : List ,
40944090 ) -> Float [" b" ]:
40954091 """Compute interface metric.
@@ -4116,7 +4112,7 @@ def compute_interface_metric(
41164112 chain_wise_iptm_mask ,
41174113 unique_chains ,
41184114 ) = self .compute_confidence_score .compute_ptm (
4119- confidence_head_logits . pae , asym_id , has_frame , compute_chain_wise_iptm = True
4115+ pae_logits , asym_id , has_frame , molecule_atom_lens , compute_chain_wise_iptm = True
41204116 )
41214117
41224118 # Section 5.9.3 equation 20
@@ -4141,7 +4137,7 @@ def compute_interface_metric(
41414137 @typecheck
41424138 def compute_modified_residue_score (
41434139 self ,
4144- confidence_head_logits : ConfidenceHeadLogits ,
4140+ plddt_logits : Float [ 'b plddt m' ] ,
41454141 atom_mask : Bool ["b m" ],
41464142 atom_is_modified_residue : Int ["b m" ],
41474143 ) -> Float [" b" ]:
@@ -4155,9 +4151,7 @@ def compute_modified_residue_score(
41554151
41564152 # Section 5.9.3.4
41574153
4158- plddt = self .compute_confidence_score .compute_plddt (
4159- confidence_head_logits .plddt ,
4160- )
4154+ plddt = self .compute_confidence_score .compute_plddt (plddt_logits )
41614155
41624156 mask = atom_is_modified_residue * atom_mask
41634157 plddt_mean = masked_average (plddt , mask , dim = - 1 , eps = self .eps )
@@ -5909,11 +5903,13 @@ def forward(
59095903 & valid_atom_indices_for_frame
59105904 )
59115905
5906+ align_error_mask = batch_repeat_interleave (align_error_mask , molecule_atom_lens )
5907+
59125908 # align error
59135909
59145910 align_error = self .compute_alignment_error (
5915- denoised_molecule_pos ,
5916- molecule_pos ,
5911+ denoised_atom_pos ,
5912+ atom_pos ,
59175913 pred_frames ,
59185914 frames ,
59195915 mask = align_error_mask ,
@@ -6043,8 +6039,7 @@ def forward(
60436039 # determine which mask to use for confidence head labels
60446040
60456041 label_mask = atom_mask
6046-
6047- label_pairwise_mask = to_pairwise_mask (mask )
6042+ label_pairwise_mask = to_pairwise_mask (atom_mask )
60486043
60496044 # cross entropy losses
60506045
@@ -6076,7 +6071,7 @@ def cross_entropy_with_weight(logits, labels, weight, ignore_index: int):
60766071 f"pde_labels shape { pde_labels .shape [- 1 ]} does not match "
60776072 f"ch_logits.pde shape { ch_logits .pde .shape [- 1 ]} "
60786073 )
6079- pde_labels = torch .where (label_pairwise_mask , pde_labels , ignore )
6074+ pde_labels = torch .where (to_pairwise_mask ( mask ) , pde_labels , ignore )
60806075 pde_loss = cross_entropy_with_weight (ch_logits .pde , pde_labels , confidence_weight , ignore )
60816076
60826077 if exists (plddt_labels ):
0 commit comments