@@ -3134,32 +3134,18 @@ def __init__(
31343134 @typecheck
31353135 def forward (
31363136 self ,
3137- pred_coords : Float ['b m_or_n 3' ],
3138- true_coords : Float ['b m_or_n 3' ],
3137+ pred_coords : Float ['b n 3' ],
3138+ true_coords : Float ['b n 3' ],
31393139 pred_frames : Float ['b n 3 3' ],
31403140 true_frames : Float ['b n 3 3' ],
3141- mask : Bool ['b m_or_n' ] | None = None ,
3142- molecule_atom_lens : Int ['b n' ] | None = None
3143- ) -> Float ['b m_or_n m_or_n' ]:
3141+ mask : Bool ['b n' ] | None = None ,
3142+ ) -> Float ['b n n' ]:
31443143 """
31453144 pred_coords: predicted coordinates
31463145 true_coords: true coordinates
31473146 pred_frames: predicted frames
31483147 true_frames: true frames
31493148 """
3150-
3151- # detect whether using atom or residue resolution
3152-
3153- is_atom_resolution = pred_coords .shape [1 ] != pred_frames .shape [1 ]
3154- assert not is_atom_resolution or exists (molecule_atom_lens ), '`molecule_atom_lens` must be passed in for atom resolution alignment error'
3155-
3156- if is_atom_resolution :
3157- pred_frames = batch_repeat_interleave (pred_frames , molecule_atom_lens )
3158- true_frames = batch_repeat_interleave (true_frames , molecule_atom_lens )
3159-
3160- if not exists (mask ) and exists (molecule_atom_lens ):
3161- mask = batch_repeat_interleave (molecule_atom_lens > 0 , molecule_atom_lens )
3162-
31633149 # to pairs
31643150
31653151 seq = pred_coords .shape [1 ]
@@ -3681,7 +3667,6 @@ def forward(
36813667 asym_id : Int ["b n" ],
36823668 has_frame : Bool ["b n" ],
36833669 ptm_residue_weight : Float ["b n" ] | None = None ,
3684- molecule_atom_lens : Int ["b n" ] | None = None ,
36853670 multimer_mode : bool = True ,
36863671 ) -> ConfidenceScore :
36873672 """Main function to compute confidence score.
@@ -3698,7 +3683,6 @@ def forward(
36983683 # Section 5.9.1 equation 17
36993684 ptm = self .compute_ptm (
37003685 confidence_head_logits .pae , asym_id , has_frame , ptm_residue_weight , interface = False ,
3701- molecule_atom_lens = molecule_atom_lens ,
37023686 )
37033687
37043688 iptm = None
@@ -3707,7 +3691,6 @@ def forward(
37073691 # Section 5.9.2 equation 18
37083692 iptm = self .compute_ptm (
37093693 confidence_head_logits .pae , asym_id , has_frame , ptm_residue_weight , interface = True ,
3710- molecule_atom_lens = molecule_atom_lens ,
37113694 )
37123695
37133696 confidence_score = ConfidenceScore (plddt = plddt , ptm = ptm , iptm = iptm )
@@ -3735,11 +3718,10 @@ def compute_plddt(
37353718 @typecheck
37363719 def compute_ptm (
37373720 self ,
3738- logits : Float ["b pae m_or_n m_or_n " ],
3721+ pae_logits : Float ["b pae n n " ],
37393722 asym_id : Int ["b n" ],
37403723 has_frame : Bool ["b n" ],
37413724 residue_weights : Float ["b n" ] | None = None ,
3742- molecule_atom_lens : Int ["b n" ] | None = None ,
37433725 interface : bool = False ,
37443726 compute_chain_wise_iptm : bool = False ,
37453727 ) -> Float [" b" ] | Tuple [Float ["b chains chains" ], Bool ["b chains chains" ], Int ["b chains" ]]:
@@ -3753,25 +3735,15 @@ def compute_ptm(
37533735 :param interface: bool
37543736 :param compute_chain_wise_iptm: bool
37553737 :return: pTM
3756- """
3757-
3758- is_atom_resolution = logits .shape [- 1 ] != asym_id .shape [- 1 ]
3759- assert not is_atom_resolution or exists (molecule_atom_lens ), '`molecule_atom_lens` must be passed in for atom resolution pTM'
3760-
3761- if is_atom_resolution :
3762- asym_id = batch_repeat_interleave (asym_id , molecule_atom_lens )
3763- has_frame = batch_repeat_interleave (has_frame , molecule_atom_lens )
3764- if exists (residue_weights ):
3765- residue_weights = batch_repeat_interleave (residue_weights , molecule_atom_lens )
3766-
3738+ """
37673739 if not exists (residue_weights ):
37683740 residue_weights = torch .ones_like (has_frame )
37693741
37703742 residue_weights = residue_weights * has_frame
37713743
3772- num_batch = logits .shape [ 0 ]
3773- num_res = logits . shape [ - 1 ]
3774- logits = rearrange (logits , "b c i j -> b i j c" )
3744+ num_batch , * _ , num_res , device = * pae_logits .shape , pae_logits . device
3745+
3746+ pae_logits = rearrange (pae_logits , "b c i j -> b i j c" )
37753747
37763748 bin_centers = self ._calculate_bin_centers (self .pae_breaks )
37773749
@@ -3788,7 +3760,7 @@ def compute_ptm(
37883760 tm_per_bin = 1.0 / (1 + torch .square (bin_centers [None , :]) / torch .square (d0 [..., None ]))
37893761
37903762 # Convert logits to probs.
3791- probs = F .softmax (logits , dim = - 1 )
3763+ probs = F .softmax (pae_logits , dim = - 1 )
37923764
37933765 # E_distances tm(distance).
37943766 predicted_tm_term = einsum (probs , tm_per_bin , "b i j pae, b pae -> b i j " )
@@ -3801,7 +3773,7 @@ def compute_ptm(
38013773 max_chains = max (len (chains ) for chains in unique_chains )
38023774
38033775 chain_wise_iptm = torch .zeros (
3804- (num_batch , max_chains , max_chains ), device = logits . device
3776+ (num_batch , max_chains , max_chains ), device = device
38053777 )
38063778 chain_wise_iptm_mask = torch .zeros_like (chain_wise_iptm ).bool ()
38073779
@@ -3837,7 +3809,7 @@ def compute_ptm(
38373809 return chain_wise_iptm , chain_wise_iptm_mask , torch .tensor (unique_chains )
38383810
38393811 else :
3840- pair_mask = torch .ones (size = (num_batch , num_res , num_res ), device = logits . device ).bool ()
3812+ pair_mask = torch .ones (size = (num_batch , num_res , num_res ), device = device ).bool ()
38413813 if interface :
38423814 pair_mask *= asym_id [:, :, None ] != asym_id [:, None , :]
38433815
@@ -3857,13 +3829,14 @@ def compute_ptm(
38573829 @typecheck
38583830 def compute_pde (
38593831 self ,
3860- logits : Float ["b pde n n" ],
3832+ pde_logits : Float ["b pde n n" ],
38613833 tok_repr_atm_mask : Bool ["b n" ],
38623834 ) -> Float ["b n n" ]:
38633835 """Compute PDE from logits."""
3864- logits = rearrange (logits , "b pde i j -> b i j pde" )
3836+
3837+ pde_logits = rearrange (pde_logits , "b pde i j -> b i j pde" )
38653838 bin_centers = self ._calculate_bin_centers (self .pde_breaks )
3866- probs = F .softmax (logits , dim = - 1 )
3839+ probs = F .softmax (pde_logits , dim = - 1 )
38673840
38683841 pde = einsum (probs , bin_centers , "b i j pde, pde -> b i j" )
38693842
@@ -5968,7 +5941,6 @@ def forward(
59685941 pred_frames ,
59695942 frames ,
59705943 mask = align_error_mask ,
5971- molecule_atom_lens = molecule_atom_lens ,
59725944 )
59735945
59745946 # calculate pae labels as alignment error binned to 64 (0 - 32A) (TODO: double-check correctness of `distance_to_bins`'s bin assignments)
0 commit comments