Skip to content

Commit 45d887e

Browse files
committed
allow for overriding of weight dict config when computing model selection score
1 parent e09cd46 commit 45d887e

File tree

2 files changed

+54
-52
lines changed

2 files changed

+54
-52
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3518,6 +3518,46 @@ def get_cid_molecule_type(
35183518
return molecule_type
35193519

35203520
class ComputeModelSelectionScore(Module):
3521+
INITIAL_TRAINING_DICT = {
3522+
'protein-protein': {'interface': 20, 'intra-chain': 20},
3523+
'DNA-protein': {'interface': 10},
3524+
'RNA-protein': {'interface': 10},
3525+
3526+
'ligand-protein': {'interface': 10},
3527+
'DNA-ligand': {'interface': 5},
3528+
'RNA-ligand': {'interface': 5},
3529+
3530+
'DNA-DNA': {'intra-chain': 4},
3531+
'RNA-RNA': {'intra-chain': 16},
3532+
'ligand-ligand': {'intra-chain': 20},
3533+
'metal_ion-metal_ion': {'intra-chain': 10},
3534+
'unresolved': {'unresolved': 10}
3535+
}
3536+
3537+
FINETUNING_DICT = {
3538+
'protein-protein': {'interface': 20, 'intra-chain': 20},
3539+
'DNA-protein': {'interface': 10},
3540+
'RNA-protein': {'interface': 2},
3541+
3542+
'ligand-protein': {'interface': 10},
3543+
'DNA-ligand': {'interface': 5},
3544+
'RNA-ligand': {'interface': 2},
3545+
3546+
'DNA-DNA': {'intra-chain': 4},
3547+
'RNA-RNA': {'intra-chain': 16},
3548+
'ligand-ligand': {'intra-chain': 20},
3549+
'metal_ion-metal_ion': {'intra-chain': 0},
3550+
3551+
'unresolved': {'unresolved': 10}
3552+
}
3553+
3554+
TYPE_MAPPING = {
3555+
IS_PROTEIN: 'protein',
3556+
IS_DNA: 'DNA',
3557+
IS_RNA: 'RNA',
3558+
IS_LIGAND: 'ligand',
3559+
IS_METAL_ION: 'metal_ion'
3560+
}
35213561

35223562
@typecheck
35233563
def __init__(
@@ -3526,15 +3566,20 @@ def __init__(
35263566
dist_breaks: Float[' dist_break'] = torch.linspace(2.3125,21.6875,63,),
35273567
nucleic_acid_cutoff: float = 30.0,
35283568
other_cutoff: float = 15.0,
3529-
contact_mask_threshold: float = 8.0
3569+
contact_mask_threshold: float = 8.0,
3570+
is_fine_tuning: bool = False,
3571+
weight_dict_config: dict = None
35303572
):
35313573

35323574
super().__init__()
35333575
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
3576+
35343577
self.eps = eps
35353578
self.nucleic_acid_cutoff = nucleic_acid_cutoff
35363579
self.other_cutoff = other_cutoff
35373580
self.contact_mask_threshold = contact_mask_threshold
3581+
self.is_fine_tuning = is_fine_tuning
3582+
self.weight_dict_config = weight_dict_config
35383583

35393584
self.register_buffer('dist_breaks', dist_breaks)
35403585

@@ -3598,7 +3643,6 @@ def compute_lddt(
35983643
# Compute distance difference for all pairs of atoms
35993644
dist_diff = torch.abs(true_dists - pred_dists)
36003645

3601-
36023646
lddt = (
36033647
((0.5 - dist_diff) >=0).float() +
36043648
((1.0 - dist_diff) >=0).float() +
@@ -3653,9 +3697,7 @@ def compute_chain_pair_lddt(
36533697

36543698
if asym_mask_a.ndim == 1:
36553699
args = [asym_mask_a, asym_mask_b, pred_coords, true_coords, is_molecule_types, coords_mask ]
3656-
args = list(
3657-
map(lambda x: x.unsqueeze(0), args)
3658-
)
3700+
args = [x.unsqueeze(0) for x in args]
36593701
asym_mask_a, asym_mask_b, pred_coords, true_coords, is_molecule_types, coords_mask = args
36603702

36613703

@@ -3677,59 +3719,18 @@ def get_lddt_weight(
36773719
type_chain_a,
36783720
type_chain_b,
36793721
lddt_type: Literal['interface', 'intra-chain', 'unresolved'],
3680-
is_fine_tuning: bool = False,
3722+
is_fine_tuning: bool = None,
36813723
):
3724+
is_fine_tuning = default(is_fine_tuning, self.is_fine_tuning)
36823725

3683-
type_mapping = {
3684-
IS_PROTEIN: 'protein',
3685-
IS_DNA: 'DNA',
3686-
IS_RNA: 'RNA',
3687-
IS_LIGAND: 'ligand',
3688-
IS_METAL_ION: 'metal_ion'
3689-
}
3690-
3691-
initial_training_dict = {
3692-
'protein-protein': {'interface': 20, 'intra-chain': 20},
3693-
'DNA-protein': {'interface': 10},
3694-
'RNA-protein': {'interface': 10},
3695-
3696-
'ligand-protein': {'interface': 10},
3697-
'DNA-ligand': {'interface': 5},
3698-
'RNA-ligand': {'interface': 5},
3699-
3700-
'DNA-DNA': {'intra-chain': 4},
3701-
'RNA-RNA': {'intra-chain': 16},
3702-
'ligand-ligand': {'intra-chain': 20},
3703-
'metal_ion-metal_ion': {'intra-chain': 10},
3704-
3705-
'unresolved': {'unresolved': 10}
3706-
}
3707-
3708-
fine_tuning_dict = {
3709-
'protein-protein': {'interface': 20, 'intra-chain': 20},
3710-
'DNA-protein': {'interface': 10},
3711-
'RNA-protein': {'interface': 2},
3712-
3713-
'ligand-protein': {'interface': 10},
3714-
'DNA-ligand': {'interface': 5},
3715-
'RNA-ligand': {'interface': 2},
3716-
3717-
'DNA-DNA': {'intra-chain': 4},
3718-
'RNA-RNA': {'intra-chain': 16},
3719-
'ligand-ligand': {'intra-chain': 20},
3720-
'metal_ion-metal_ion': {'intra-chain': 0},
3721-
3722-
'unresolved': {'unresolved': 10}
3723-
}
3724-
3725-
weight_dict = fine_tuning_dict if is_fine_tuning else initial_training_dict
3726+
weight_dict = default(self.weight_dict_config, self.FINETUNING_DICT if is_fine_tuning else self.INITIAL_TRAINING_DICT)
37263727

37273728
if lddt_type == 'unresolved':
37283729
weight = weight_dict.get(lddt_type, {}).get(lddt_type, None)
37293730
assert weight
37303731
return weight
37313732

3732-
interface_type = sorted([type_mapping[type_chain_a], type_mapping[type_chain_b]])
3733+
interface_type = sorted([self.TYPE_MAPPING[type_chain_a], self.TYPE_MAPPING[type_chain_b]])
37333734
interface_type = '-'.join(interface_type)
37343735
weight = weight_dict.get(interface_type, {}).get(lddt_type, None)
37353736
assert weight, f"Weight not found for {interface_type} {lddt_type}"
@@ -3748,8 +3749,9 @@ def compute_weighted_lddt(
37483749
molecule_atom_lens: Int['b n'],
37493750
# additional input
37503751
chains_list: List[Tuple[int, int] | Tuple[int]],
3751-
is_fine_tuning: bool = False,
3752+
is_fine_tuning: bool = None,
37523753
):
3754+
is_fine_tuning = default(is_fine_tuning, self.is_fine_tuning)
37533755

37543756
device = pred_coords.device
37553757
batch_size = pred_coords.shape[0]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.52"
3+
version = "0.2.53"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)