@@ -3492,12 +3492,14 @@ def compute_modified_residue_score(
34923492 return plddt_mean
34933493
34943494# model selection
3495+
3496+ @typecheck
34953497def get_cid_molecule_type (
34963498 cid : int ,
34973499 asym_id : Int [' n' ],
3498- is_molecule_types : Bool ['n {IS_MOLECULE_TYPES}' ],
3500+ is_molecule_types : Bool [f 'n { IS_MOLECULE_TYPES } ' ],
34993501 return_one_hot : bool = False ,
3500- ) -> int | Bool [' {IS_MOLECULE_TYPES}' ]:
3502+ ) -> int | Bool [f ' { IS_MOLECULE_TYPES } ' ]:
35013503 """
35023504
35033505 get the molecule type for where asym_id == cid
@@ -3517,6 +3519,7 @@ def get_cid_molecule_type(
35173519
35183520class ComputeModelSelectionScore (Module ):
35193521
3522+ @typecheck
35203523 def __init__ (
35213524 self ,
35223525 eps : float = 1e-8 ,
@@ -3535,6 +3538,7 @@ def __init__(
35353538
35363539 self .register_buffer ('dist_breaks' , dist_breaks )
35373540
3541+ @typecheck
35383542 def compute_gpde (
35393543 self ,
35403544 pde_logits : Float ['b pde n n' ],
@@ -3570,6 +3574,7 @@ def compute_gpde(
35703574
35713575 return gpde
35723576
3577+ @typecheck
35733578 def compute_lddt (
35743579 self ,
35753580 pred_coords : Float ['b m 3' ],
@@ -3628,13 +3633,14 @@ def compute_lddt(
36283633
36293634 return lddt_mean
36303635
3636+ @typecheck
36313637 def compute_chain_pair_lddt (
36323638 self ,
36333639 asym_mask_a : Bool ['b m' ] | Bool [' m' ],
36343640 asym_mask_b : Bool ['b m' ] | Bool [' m' ],
36353641 pred_coords : Float ['b m 3' ] | Float ['m 3' ],
36363642 true_coords : Float ['b m 3' ] | Float ['m 3' ],
3637- is_molecule_types : Int [ 'b m {IS_MOLECULE_TYPES}' ] | Int [ 'm {IS_MOLECULE_TYPES}' ],
3643+ is_molecule_types : Bool [ f 'b m { IS_MOLECULE_TYPES } ' ] | Bool [ f 'm { IS_MOLECULE_TYPES } ' ],
36383644 coords_mask : Bool ['b m' ] | Bool [' m' ] | None = None ,
36393645 ) -> Float [' b' ]:
36403646 """
@@ -3665,6 +3671,7 @@ def compute_chain_pair_lddt(
36653671
36663672 return lddt
36673673
3674+ @typecheck
36683675 def get_lddt_weight (
36693676 self ,
36703677 type_chain_a ,
@@ -3718,7 +3725,7 @@ def get_lddt_weight(
37183725 weight_dict = fine_tuning_dict if is_fine_tuning else initial_training_dict
37193726
37203727 if lddt_type == 'unresolved' :
3721- weight = weight_dict .get (lddt_type , None ).get (lddt_type , None )
3728+ weight = weight_dict .get (lddt_type , {} ).get (lddt_type , None )
37223729 assert weight
37233730 return weight
37243731
@@ -3728,6 +3735,7 @@ def get_lddt_weight(
37283735 assert weight , f"Weight not found for { interface_type } { lddt_type } "
37293736 return weight
37303737
3738+ @typecheck
37313739 def compute_weighted_lddt (
37323740 self ,
37333741 # atom level input
@@ -3736,7 +3744,7 @@ def compute_weighted_lddt(
37363744 atom_mask : Bool ['b m' ] | None ,
37373745 # token level input
37383746 asym_id : Int ['b n' ],
3739- is_molecule_types : Bool ['b n {IS_MOLECULE_TYPES}' ],
3747+ is_molecule_types : Bool [f 'b n { IS_MOLECULE_TYPES } ' ],
37403748 molecule_atom_lens : Int ['b n' ],
37413749 # additional input
37423750 chains_list : List [Tuple [int , int ] | Tuple [int ]],
0 commit comments