@@ -3518,6 +3518,46 @@ def get_cid_molecule_type(
35183518 return molecule_type
35193519
35203520class ComputeModelSelectionScore (Module ):
# lDDT weight table that `get_lddt_weight` falls back to when fine-tuning is
# off and no `weight_dict_config` override was given.
# Outer key: the two chain-type names sorted and joined with '-' (or the
# literal 'unresolved'). Inner key: the lDDT kind being weighted.
INITIAL_TRAINING_DICT = {
    # protein-centred pairs
    'protein-protein': {'interface': 20, 'intra-chain': 20},
    'DNA-protein': {'interface': 10},
    'RNA-protein': {'interface': 10},

    # ligand interfaces
    'ligand-protein': {'interface': 10},
    'DNA-ligand': {'interface': 5},
    'RNA-ligand': {'interface': 5},

    # same-type pairs, scored intra-chain only
    'DNA-DNA': {'intra-chain': 4},
    'RNA-RNA': {'intra-chain': 16},
    'ligand-ligand': {'intra-chain': 20},
    'metal_ion-metal_ion': {'intra-chain': 10},

    # looked up as weight_dict['unresolved']['unresolved']
    'unresolved': {'unresolved': 10},
}
3536+
# lDDT weight table that `get_lddt_weight` falls back to when fine-tuning is
# on and no `weight_dict_config` override was given. Same layout as the
# initial-training table: outer key is the sorted, '-'-joined pair of chain-type
# names (or 'unresolved'); inner key is the lDDT kind.
# NOTE: this table contains a weight of 0, so code that checks whether a weight
# was found must compare against None rather than rely on truthiness.
FINETUNING_DICT = {
    # protein-centred pairs
    'protein-protein': {'interface': 20, 'intra-chain': 20},
    'DNA-protein': {'interface': 10},
    'RNA-protein': {'interface': 2},

    # ligand interfaces
    'ligand-protein': {'interface': 10},
    'DNA-ligand': {'interface': 5},
    'RNA-ligand': {'interface': 2},

    # same-type pairs, scored intra-chain only
    'DNA-DNA': {'intra-chain': 4},
    'RNA-RNA': {'intra-chain': 16},
    'ligand-ligand': {'intra-chain': 20},
    'metal_ion-metal_ion': {'intra-chain': 0},

    # looked up as weight_dict['unresolved']['unresolved']
    'unresolved': {'unresolved': 10},
}
3553+
# Translates a molecule-type identifier (the IS_* constants) into the name used
# to build the pair keys of the weight tables, e.g. 'DNA' + 'protein' ->
# 'DNA-protein' after sorting and joining in `get_lddt_weight`.
TYPE_MAPPING = {
    IS_PROTEIN: 'protein',
    IS_DNA: 'DNA',
    IS_RNA: 'RNA',
    IS_LIGAND: 'ligand',
    IS_METAL_ION: 'metal_ion',
}
35213561
35223562 @typecheck
35233563 def __init__ (
@@ -3526,15 +3566,20 @@ def __init__(
35263566 dist_breaks : Float [' dist_break' ] = torch .linspace (2.3125 ,21.6875 ,63 ,),
35273567 nucleic_acid_cutoff : float = 30.0 ,
35283568 other_cutoff : float = 15.0 ,
3529- contact_mask_threshold : float = 8.0
3569+ contact_mask_threshold : float = 8.0 ,
3570+ is_fine_tuning : bool = False ,
3571+ weight_dict_config : dict = None
35303572 ):
35313573
35323574 super ().__init__ ()
35333575 self .compute_confidence_score = ComputeConfidenceScore (eps = eps )
3576+
35343577 self .eps = eps
35353578 self .nucleic_acid_cutoff = nucleic_acid_cutoff
35363579 self .other_cutoff = other_cutoff
35373580 self .contact_mask_threshold = contact_mask_threshold
3581+ self .is_fine_tuning = is_fine_tuning
3582+ self .weight_dict_config = weight_dict_config
35383583
35393584 self .register_buffer ('dist_breaks' , dist_breaks )
35403585
@@ -3598,7 +3643,6 @@ def compute_lddt(
35983643 # Compute distance difference for all pairs of atoms
35993644 dist_diff = torch .abs (true_dists - pred_dists )
36003645
3601-
36023646 lddt = (
36033647 ((0.5 - dist_diff ) >= 0 ).float () +
36043648 ((1.0 - dist_diff ) >= 0 ).float () +
@@ -3653,9 +3697,7 @@ def compute_chain_pair_lddt(
36533697
36543698 if asym_mask_a .ndim == 1 :
36553699 args = [asym_mask_a , asym_mask_b , pred_coords , true_coords , is_molecule_types , coords_mask ]
3656- args = list (
3657- map (lambda x : x .unsqueeze (0 ), args )
3658- )
3700+ args = [x .unsqueeze (0 ) for x in args ]
36593701 asym_mask_a , asym_mask_b , pred_coords , true_coords , is_molecule_types , coords_mask = args
36603702
36613703
def get_lddt_weight(
    self,
    type_chain_a,
    type_chain_b,
    lddt_type: Literal['interface', 'intra-chain', 'unresolved'],
    is_fine_tuning: bool | None = None,
):
    """Look up the lDDT weight for a pair of chain types.

    Args:
        type_chain_a: Molecule-type identifier of the first chain; must be a
            key of `self.TYPE_MAPPING` unless `lddt_type` is 'unresolved'.
        type_chain_b: Molecule-type identifier of the second chain; same
            requirement as `type_chain_a`.
        lddt_type: Which lDDT kind the weight is for.
        is_fine_tuning: Selects the fine-tuning table instead of the
            initial-training one. `None` defers to `self.is_fine_tuning`.

    Returns:
        The weight stored in the active table. A weight of 0 is a valid
        result (the fine-tuning table contains one).

    Raises:
        AssertionError: The active table has no entry for the requested
            pair / lDDT kind.
        KeyError: A chain type is not a key of `self.TYPE_MAPPING`.
    """
    # NOTE(review): `default` is assumed to return its first argument unless
    # that argument is None - confirm against the helper's definition.
    is_fine_tuning = default(is_fine_tuning, self.is_fine_tuning)

    # A user-supplied table takes precedence over both built-in tables.
    builtin_table = self.FINETUNING_DICT if is_fine_tuning else self.INITIAL_TRAINING_DICT
    weight_dict = default(self.weight_dict_config, builtin_table)

    if lddt_type == 'unresolved':
        # Stored under the same string for outer and inner key.
        weight = weight_dict.get(lddt_type, {}).get(lddt_type, None)
        # Compare against None: a configured weight of 0 is falsy but valid.
        assert weight is not None, f"Weight not found for {lddt_type}"
        return weight

    # Sorting makes the key independent of argument order, so (protein, DNA)
    # and (DNA, protein) both resolve to 'DNA-protein'.
    interface_type = sorted([self.TYPE_MAPPING[type_chain_a], self.TYPE_MAPPING[type_chain_b]])
    interface_type = '-'.join(interface_type)
    weight = weight_dict.get(interface_type, {}).get(lddt_type, None)
    # Compare against None: a configured weight of 0 is falsy but valid.
    assert weight is not None, f"Weight not found for {interface_type} {lddt_type}"
    return weight
@@ -3748,8 +3749,9 @@ def compute_weighted_lddt(
37483749 molecule_atom_lens : Int ['b n' ],
37493750 # additional input
37503751 chains_list : List [Tuple [int , int ] | Tuple [int ]],
3751- is_fine_tuning : bool = False ,
3752+ is_fine_tuning : bool = None ,
37523753 ):
3754+ is_fine_tuning = default (is_fine_tuning , self .is_fine_tuning )
37533755
37543756 device = pred_coords .device
37553757 batch_size = pred_coords .shape [0 ]
0 commit comments