@@ -292,7 +292,7 @@ def mean_pool_with_lens(
292292def repeat_consecutive_with_lens (
293293 feats : Float ['b n ...' ] | Bool ['b n ...' ] | Bool ['b n' ] | Int ['b n' ],
294294 lens : Int ['b n' ],
295- mask_value : Optional [ float | int | bool ] = None ,
295+ mask_value : float | int | bool | None = None ,
296296) -> Float ['b m ...' ] | Bool ['b m ...' ] | Bool ['b m' ] | Int ['b m' ]:
297297
298298 device , dtype = feats .device , feats .dtype
@@ -3227,12 +3227,11 @@ def compute_ptm(
32273227 def compute_pde (
32283228 self ,
32293229 logits : Float ['b pde n n' ],
3230- breaks : Float [' pde_break' ],
32313230 tok_repr_atm_mask : Bool [' b n' ],
32323231 )-> Float [' b n n' ]:
32333232
32343233 logits = rearrange (logits , 'b pde i j -> b i j pde' )
3235- bin_centers = self ._calculate_bin_centers (breaks . to ( logits . device ) )
3234+ bin_centers = self ._calculate_bin_centers (self . pde_breaks )
32363235 probs = F .softmax (logits , dim = - 1 )
32373236
32383237 pde = einsum (probs , bin_centers , 'b i j pde, pde -> b i j ' )
@@ -3495,7 +3494,7 @@ def compute_modified_residue_score(
34953494# model selection
34963495def get_cid_molecule_type (
34973496 cid : int ,
3498- asym_id : Int ['n' ],
3497+ asym_id : Int [' n' ],
34993498 is_molecule_types : Bool ['n {IS_MOLECULE_TYPES}' ],
35003499 return_one_hot : bool = False ,
35013500 ) -> int | Bool [' {IS_MOLECULE_TYPES}' ]:
@@ -3505,17 +3504,15 @@ def get_cid_molecule_type(
35053504 """
35063505
35073506 cid_is_molecule_types = is_molecule_types [asym_id == cid ]
3508- valid = torch .all (
3509- einx .equal ('b i, i -> b i' ,
3510- cid_is_molecule_types ,
3511- cid_is_molecule_types [0 ])
3512- )
3507+ molecule_type , rest_molecule_type = cid_is_molecule_types [0 ], cid_is_molecule_types [1 :]
3508+
3509+ valid = einx .equal ('b i, i -> b i' , rest_molecule_type , molecule_type ).all ()
3510+
35133511 assert valid , f"Ambiguous molecule types for chain { cid } "
35143512
3515- if return_one_hot :
3516- molecule_type = cid_is_molecule_types [0 ]
3517- else :
3518- molecule_type = cid_is_molecule_types [0 ].int ().argmax ().item ()
3513+ if not return_one_hot :
3514+ molecule_type = molecule_type .int ().argmax ().item ()
3515+
35193516 return molecule_type
35203517
35213518class ComputeModelSelectionScore (Module ):
@@ -3525,14 +3522,17 @@ def __init__(
35253522 eps : float = 1e-8 ,
35263523 dist_breaks : Float [' dist_break' ] = torch .linspace (2.3125 ,21.6875 ,63 ,),
35273524 nucleic_acid_cutoff : float = 30.0 ,
3528- other_cutoff : float = 15.0
3525+ other_cutoff : float = 15.0 ,
3526+ contact_mask_threshold : float = 8.0
35293527 ):
35303528
35313529 super ().__init__ ()
35323530 self .compute_confidence_score = ComputeConfidenceScore (eps = eps )
35333531 self .eps = eps
35343532 self .nucleic_acid_cutoff = nucleic_acid_cutoff
35353533 self .other_cutoff = other_cutoff
3534+ self .contact_mask_threshold = contact_mask_threshold
3535+
35363536 self .register_buffer ('dist_breaks' , dist_breaks )
35373537
35383538 def compute_gpde (
@@ -3548,13 +3548,14 @@ def compute_gpde(
35483548 tok_repr_atm_mask: [b n] true if token representation atoms exists
35493549 """
35503550
3551- pde = self .compute_confidence_score .compute_pde (
3552- pde_logits , self .compute_confidence_score .pde_breaks , tok_repr_atm_mask )
3551+ pde = self .compute_confidence_score .compute_pde (pde_logits , tok_repr_atm_mask )
35533552
35543553 dist_logits = rearrange (dist_logits , 'b dist i j -> b i j dist' )
35553554 dist_probs = F .softmax (dist_logits , dim = - 1 )
3556- contact_mask = dist_breaks < 8.0
3557- contact_mask = torch .cat ([contact_mask , torch .zeros ([1 ], device = dist_logits .device )]).bool ()
3555+
3556+ contact_mask = dist_breaks < self .contact_mask_threshold
3557+ contact_mask = F .pad (contact_mask , (0 , 1 ), value = True )
3558+
35583559 contact_prob = einx .where (
35593560 ' dist, b i j dist, -> b i j dist' ,
35603561 contact_mask , dist_probs , 0.
@@ -3577,7 +3578,7 @@ def compute_lddt(
35773578 is_rna : Bool ['b m' ],
35783579 pairwise_mask : Bool ['b m m' ],
35793580 coords_mask : Bool ['b m' ] | None = None ,
3580- ) -> Float ['b' ]:
3581+ ) -> Float [' b' ]:
35813582 """
35823583 pred_coords: predicted coordinates
35833584 true_coords: true coordinates
@@ -3635,7 +3636,7 @@ def compute_chain_pair_lddt(
36353636 true_coords : Float ['b m 3' ] | Float ['m 3' ],
36363637 is_molecule_types : Int ['b m {IS_MOLECULE_TYPES}' ] | Int ['m {IS_MOLECULE_TYPES}' ],
36373638 coords_mask : Bool ['b m' ] | Bool [' m' ] | None = None ,
3638- ) -> Float ['b' ]:
3639+ ) -> Float [' b' ]:
36393640 """
36403641
36413642 plddt between atoms maked by asym_mask_a and asym_mask_b
0 commit comments