Commit 5b1e921

Add new embeddings (#252)

* Update alphafold3.py
* Update inputs.py
* Update pyproject.toml
* Update inputs.py
* Update test_af3.py

1 parent 08a546f commit 5b1e921

4 files changed: +267 -6 lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 85 additions & 0 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import esm
 import random
 import sh
 from math import pi, sqrt
@@ -63,6 +64,7 @@
     IS_RNA,
     IS_LIGAND,
     IS_METAL_ION,
+    NUM_HUMAN_AMINO_ACIDS,
     NUM_MOLECULE_IDS,
     NUM_MSA_ONE_HOT,
     DEFAULT_NUM_MOLECULE_MODS,
@@ -136,6 +138,8 @@
 dmi - feature dimension (msa input)
 dmf - additional msa feats derived from msa (has_deletion and deletion_value)
 dtf - additional token feats derived from msa (profile and deletion_mean)
+dac - additional pairwise token constraint embeddings
+dpe - additional protein language model embeddings from esm
 t - templates
 s - msa
 r - registers
@@ -5951,6 +5955,9 @@ def __init__(
         checkpoint_diffusion_module = False,
         detach_when_recycling = True,
         pdb_training_set=True,
+        plm_embeddings: Literal["esm2_t33_650M_UR50D"] | None = None,
+        plm_repr_layer: int = 33,
+        constraint_embeddings: int | None = None,
     ):
         super().__init__()
 
@@ -5976,6 +5983,13 @@ def __init__(
         self.has_atom_embeds = has_atom_embeds
         self.has_atompair_embeds = has_atompair_embeds
 
+        # optional pairwise token constraint embeddings
+
+        self.constraint_embeddings = constraint_embeddings
+
+        if exists(constraint_embeddings):
+            self.constraint_embeds = LinearNoBias(constraint_embeddings, dim_pairwise)
+
         # residue or nucleotide modifications
 
         num_molecule_mods = default(num_molecule_mods, 0)
@@ -5986,6 +6000,18 @@ def __init__(
 
         self.has_molecule_mod_embeds = has_molecule_mod_embeds
 
+        # optional protein language model (PLM) embeddings
+
+        self.plm_embeddings = plm_embeddings
+
+        if exists(plm_embeddings):
+            self.plm, plm_alphabet = esm.pretrained.load_model_and_alphabet_hub(plm_embeddings)
+            self.plm_repr_layer = plm_repr_layer
+            self.plm_batch_converter = plm_alphabet.get_batch_converter()
+            self.plm_embeds = nn.Linear(self.plm.embed_dim, dim_single, bias=False)
+            for p in self.plm.parameters():
+                p.requires_grad = False
+
         # atoms per window
 
         self.atoms_per_window = atoms_per_window
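
Note: a minimal sketch of how the two new constructor options above might be enabled together. The dim_atom_inputs and dim_template_feats values, and the constraint feature dimension, are illustrative placeholders and are not taken from this commit.

    from alphafold3_pytorch import Alphafold3

    alphafold3 = Alphafold3(
        dim_atom_inputs = 77,                    # illustrative placeholder
        dim_template_feats = 108,                # illustrative placeholder
        plm_embeddings = 'esm2_t33_650M_UR50D',  # loads a frozen ESM-2 model via esm.pretrained and projects its embeddings into the single representation
        plm_repr_layer = 33,                     # ESM layer to read representations from
        constraint_embeddings = 2                # feature dimension (dac) of the pairwise token constraints; illustrative
    )
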
@@ -6295,6 +6321,40 @@ def shrink_and_perturb_(
 
         return self
 
+    @typecheck
+    def extract_plm_embeddings(self, aa_ids: Int['b n']) -> Float['b n dpe']:
+        aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
+        sequence_data = [
+            (
+                f"molecule{i}",
+                "".join(
+                    [
+                        (
+                            aa_constants.restypes[id]
+                            if 0 <= id < len(aa_constants.restypes)
+                            else "X"
+                        )
+                        for id in ids
+                    ]
+                ),
+            )
+            for i, ids in enumerate(aa_ids)
+        ]
+
+        _, _, batch_tokens = self.plm_batch_converter(sequence_data)
+        batch_tokens = batch_tokens.to(self.device)
+
+        with torch.no_grad():
+            results = self.plm(batch_tokens, repr_layers=[self.plm_repr_layer])
+        token_representations = results["representations"][self.plm_repr_layer]
+
+        sequence_representations = []
+        for i, (_, seq) in enumerate(sequence_data):
+            sequence_representations.append(token_representations[i, 1 : len(seq) + 1])
+        plm_embeddings = torch.stack(sequence_representations, dim=0)
+
+        return plm_embeddings
+
     @typecheck
     def forward_with_alphafold3_inputs(
         self,
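
Note: extract_plm_embeddings follows the standard fair-esm inference pattern. The standalone sketch below replays that pattern outside the model; the sequences are made-up examples, and out-of-range residue ids map to 'X' exactly as in the method above.

    import torch
    import esm

    model, alphabet = esm.pretrained.load_model_and_alphabet_hub('esm2_t33_650M_UR50D')
    batch_converter = alphabet.get_batch_converter()
    model.eval()

    # illustrative sequences; the new method builds these strings from aa_ids
    data = [('molecule0', 'MKTAYIAKQR'), ('molecule1', 'GSHMXLLQTA')]
    _, _, batch_tokens = batch_converter(data)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers = [33])
    token_representations = results['representations'][33]   # (batch, tokens, embed_dim)

    # drop the prepended BOS token and trim to sequence length, as extract_plm_embeddings does
    per_residue = [token_representations[i, 1 : len(seq) + 1] for i, (_, seq) in enumerate(data)]

The method then stacks these per-sequence slices with torch.stack, which relies on every sequence in the batch sharing the same token length n.
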
@@ -6342,6 +6402,7 @@ def forward(
         distance_labels: Int['b n n'] | Int['b m m'] | None = None,
         resolved_labels: Int['b m'] | None = None,
         resolution: Float[' b'] | None = None,
+        token_constraints: Int['b n n dac'] | None = None,
         return_loss_breakdown = False,
         return_loss: bool = None,
         return_all_diffused_atom_pos: bool = False,
@@ -6488,6 +6549,30 @@ def forward(
 
         single_init = seq_unpack_one(single_init)
 
+        # handle maybe pairwise token constraint embeddings
+
+        if exists(self.constraint_embeddings):
+            assert exists(
+                token_constraints
+            ), "`token_constraints` must be provided to use constraint embeddings."
+
+            pairwise_constraint_embeds = self.constraint_embeds(token_constraints)
+            pairwise_init = pairwise_init + pairwise_constraint_embeds
+
+        # handle maybe protein language model (PLM) embeddings
+
+        if exists(self.plm_embeddings):
+            molecule_aa_ids = torch.where(
+                molecule_ids < 0,
+                NUM_HUMAN_AMINO_ACIDS,
+                molecule_ids.clamp(max=NUM_HUMAN_AMINO_ACIDS),
+            )
+
+            molecule_plm_embeddings = self.extract_plm_embeddings(molecule_aa_ids)
+            single_plm_init = self.plm_embeds(molecule_plm_embeddings)
+
+            single_init = single_init + single_plm_init
+
         # relative positional encoding
 
         relative_position_encoding = self.relative_position_encoding(
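
Note: both blocks above reduce to residual additions onto the initial pairwise and single representations. The sketch below replays them in isolation; all sizes are illustrative, and since token_constraints is annotated as Int while a bias-free linear projection consumes floating point input, a dtype cast is shown here as an assumption about how the feature would be fed in practice.

    import torch
    from torch import nn

    b, n, dac, dim_pairwise, dim_single, plm_dim = 1, 16, 2, 128, 384, 1280   # illustrative sizes

    # pairwise token constraints -> pairwise representation
    token_constraints = torch.randint(0, 2, (b, n, n, dac)).float()   # cast before the projection
    constraint_embeds = nn.Linear(dac, dim_pairwise, bias = False)
    pairwise_init = torch.zeros(b, n, n, dim_pairwise)
    pairwise_init = pairwise_init + constraint_embeds(token_constraints)

    # frozen PLM embeddings -> single representation
    molecule_plm_embeddings = torch.randn(b, n, plm_dim)              # stand-in for extract_plm_embeddings output
    plm_embeds = nn.Linear(plm_dim, dim_single, bias = False)
    single_init = torch.zeros(b, n, dim_single)
    single_init = single_init + plm_embeds(molecule_plm_embeddings)
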
