Skip to content

Commit f9ba6bd

Browse files
authored
Add full constraints set (#268)
* Update inputs.py * Update test_dataloading.py * Update inputs.py * Update weighted_pdb_sampler.py * Update inputs.py * Update plm.py * Update alphafold3.py * Update template_parsing.py * Update inputs.py * Update inputs.py * Update alphafold3.py * Update alphafold3.py
1 parent 010a989 commit f9ba6bd

File tree

6 files changed

+307
-56
lines changed

6 files changed

+307
-56
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,17 @@
5454
)
5555

5656
from alphafold3_pytorch.inputs import (
57+
CONSTRAINT_DIMS,
58+
CONSTRAINTS,
59+
CONSTRAINTS_MASK_VALUE,
5760
IS_MOLECULE_TYPES,
5861
IS_PROTEIN_INDEX,
5962
IS_DNA_INDEX,
6063
IS_RNA_INDEX,
6164
IS_LIGAND_INDEX,
6265
IS_METAL_ION_INDEX,
6366
IS_BIOMOLECULE_INDICES,
67+
IS_NON_PROTEIN_INDICES,
6468
IS_PROTEIN,
6569
IS_DNA,
6670
IS_RNA,
@@ -5954,7 +5958,7 @@ def __init__(
59545958
pdb_training_set=True,
59555959
plm_embeddings: PLMEmbedding | tuple[PLMEmbedding, ...] | None = None,
59565960
plm_kwargs: dict | tuple[dict, ...] | None = None,
5957-
constraint_embeddings: int | None = None,
5961+
constraints: List[CONSTRAINTS] | None = None,
59585962
):
59595963
super().__init__()
59605964

@@ -5983,10 +5987,18 @@ def __init__(
59835987

59845988
# optional pairwise token constraint embeddings
59855989

5986-
self.constraint_embeddings = constraint_embeddings
5990+
self.constraints = constraints
59875991

5988-
if exists(constraint_embeddings):
5989-
self.constraint_embeds = LinearNoBias(constraint_embeddings, dim_pairwise)
5992+
if exists(constraints):
5993+
self.constraint_embeds = nn.ModuleList(
5994+
[
5995+
LinearNoBias(CONSTRAINT_DIMS[constraint], dim_pairwise)
5996+
for constraint in constraints
5997+
]
5998+
)
5999+
self.learnable_constraint_masks = nn.ParameterList(
6000+
[nn.Parameter(torch.randn(1)) for _ in constraints]
6001+
)
59906002

59916003
# residue or nucleotide modifications
59926004

@@ -6538,21 +6550,30 @@ def forward(
65386550

65396551
# handle maybe pairwise token constraint embeddings
65406552

6541-
if exists(self.constraint_embeddings):
6553+
if exists(self.constraints):
65426554
assert exists(
65436555
token_constraints
65446556
), "`token_constraints` must be provided to use constraint embeddings."
65456557

6546-
pairwise_constraint_embeds = self.constraint_embeds(token_constraints)
6547-
pairwise_init = pairwise_init + pairwise_constraint_embeds
6558+
for i, constraint in enumerate(self.constraints):
6559+
constraint_slice = slice(i, i + CONSTRAINT_DIMS[constraint])
6560+
6561+
token_constraint = torch.where(
6562+
# replace fixed constraint mask values with learnable mask
6563+
token_constraints[..., constraint_slice] == CONSTRAINTS_MASK_VALUE,
6564+
self.learnable_constraint_masks[i],
6565+
token_constraints[..., constraint_slice],
6566+
)
6567+
6568+
pairwise_init = pairwise_init + self.constraint_embeds[i](token_constraint)
65486569

65496570
# handle maybe protein language model (PLM) embeddings
65506571

65516572
if exists(self.plms):
65526573
molecule_aa_ids = torch.where(
6553-
molecule_ids < 0,
6554-
NUM_HUMAN_AMINO_ACIDS,
6555-
molecule_ids.clamp(max=NUM_HUMAN_AMINO_ACIDS),
6574+
is_molecule_types[..., IS_NON_PROTEIN_INDICES].any(dim=-1),
6575+
-1,
6576+
molecule_ids,
65566577
)
65576578

65586579
plm_embeds = [plm(molecule_aa_ids) for plm in self.plms]

alphafold3_pytorch/data/template_parsing.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,8 @@ def parse_m8(
121121
template_release_date = extract_mmcif_metadata_field(
122122
template_mmcif_object, "release_date"
123123
)
124-
if not (
125-
exists(template_cutoff_date)
126-
and datetime.strptime(template_release_date, "%Y-%m-%d") <= template_cutoff_date
127-
):
124+
if exists(template_cutoff_date) and datetime.strptime(template_release_date, "%Y-%m-%d") > template_cutoff_date:
128125
continue
129-
elif not_exists(template_cutoff_date):
130-
pass
131126
template_biomol = _from_mmcif_object(
132127
template_mmcif_object, chain_ids=set(template_chain)
133128
)

alphafold3_pytorch/data/weighted_pdb_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def __init__(
191191
alpha_prot: float = 3.0,
192192
alpha_nuc: float = 3.0,
193193
alpha_ligand: float = 1.0,
194-
pdb_ids_to_skip: List[str] = [],
194+
pdb_ids_to_skip: List[str] | None = None,
195195
pdb_ids_to_keep: list[str] | None = None,
196196
):
197197
# Load chain and interface mappings
@@ -210,7 +210,7 @@ def __init__(
210210
interface_mapping = pl.read_csv(interface_mapping_path)
211211

212212
# Filter out unwanted PDB IDs
213-
if len(pdb_ids_to_skip) > 0:
213+
if exists(pdb_ids_to_skip) and len(pdb_ids_to_skip) > 0:
214214
chain_mapping = chain_mapping.filter(pl.col("pdb_id").is_in(pdb_ids_to_skip).not_())
215215
interface_mapping = interface_mapping.filter(
216216
pl.col("pdb_id").is_in(pdb_ids_to_skip).not_()

0 commit comments

Comments
 (0)