@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import esm
 import random
 import sh
 from math import pi, sqrt
@@ -63,6 +64,7 @@
     IS_RNA,
     IS_LIGAND,
     IS_METAL_ION,
+    NUM_HUMAN_AMINO_ACIDS,
     NUM_MOLECULE_IDS,
     NUM_MSA_ONE_HOT,
     DEFAULT_NUM_MOLECULE_MODS,
@@ -136,6 +138,8 @@
 dmi - feature dimension (msa input)
 dmf - additional msa feats derived from msa (has_deletion and deletion_value)
 dtf - additional token feats derived from msa (profile and deletion_mean)
+dac - additional pairwise token constraint embeddings
+dpe - additional protein language model embeddings from esm
 t - templates
 s - msa
 r - registers
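For orientation (commentary, not part of the commit): `dpe` is the width of the per-residue representation produced by the frozen ESM-2 model, which for the `esm2_t33_650M_UR50D` checkpoint named below is its `embed_dim` of 1280, while `dac` is simply whatever integer is passed as `constraint_embeddings` at construction time. A quick check of the former, assuming the `fair-esm` package is installed:

```python
import esm

# illustrative only: load the checkpoint named in this diff and inspect its embedding width
# (downloads the 650M-parameter checkpoint on first use)
plm, _ = esm.pretrained.load_model_and_alphabet_hub("esm2_t33_650M_UR50D")
print(plm.embed_dim)  # 1280 -> this is the `dpe` dimension
```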
@@ -5951,6 +5955,9 @@ def __init__(
         checkpoint_diffusion_module = False,
         detach_when_recycling = True,
         pdb_training_set = True,
+        plm_embeddings: Literal["esm2_t33_650M_UR50D"] | None = None,
+        plm_repr_layer: int = 33,
+        constraint_embeddings: int | None = None,
     ):
         super().__init__()
 
@@ -5976,6 +5983,13 @@ def __init__(
59765983 self .has_atom_embeds = has_atom_embeds
59775984 self .has_atompair_embeds = has_atompair_embeds
59785985
5986+ # optional pairwise token constraint embeddings
5987+
5988+ self .constraint_embeddings = constraint_embeddings
5989+
5990+ if exists (constraint_embeddings ):
5991+ self .constraint_embeds = LinearNoBias (constraint_embeddings , dim_pairwise )
5992+
59795993 # residue or nucleotide modifications
59805994
59815995 num_molecule_mods = default (num_molecule_mods , 0 )
@@ -5986,6 +6000,18 @@ def __init__(
 
         self.has_molecule_mod_embeds = has_molecule_mod_embeds
 
+        # optional protein language model (PLM) embeddings
+
+        self.plm_embeddings = plm_embeddings
+
+        if exists(plm_embeddings):
+            self.plm, plm_alphabet = esm.pretrained.load_model_and_alphabet_hub(plm_embeddings)
+            self.plm_repr_layer = plm_repr_layer
+            self.plm_batch_converter = plm_alphabet.get_batch_converter()
+            self.plm_embeds = nn.Linear(self.plm.embed_dim, dim_single, bias = False)
+            for p in self.plm.parameters():
+                p.requires_grad = False
+
         # atoms per window
 
         self.atoms_per_window = atoms_per_window
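Both new constructor arguments are opt-in. A minimal construction sketch (not from the repository: the two dimension arguments follow the README's example values, and `constraint_embeddings = 2` is an arbitrary illustrative channel count):

```python
from alphafold3_pytorch import Alphafold3

alphafold3 = Alphafold3(
    dim_atom_inputs = 77,                       # illustrative sizes, as in the project README
    dim_template_feats = 108,
    plm_embeddings = "esm2_t33_650M_UR50D",     # attaches a frozen ESM-2 plus a single-rep projection
    constraint_embeddings = 2                   # adds a 2-channel pairwise constraint projection
)
```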
@@ -6295,6 +6321,40 @@ def shrink_and_perturb_(
 
         return self
 
+    @typecheck
+    def extract_plm_embeddings(self, aa_ids: Int['b n']) -> Float['b n dpe']:
+        aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
+        sequence_data = [
+            (
+                f"molecule{i}",
+                "".join(
+                    [
+                        (
+                            aa_constants.restypes[id]
+                            if 0 <= id < len(aa_constants.restypes)
+                            else "X"
+                        )
+                        for id in ids
+                    ]
+                ),
+            )
+            for i, ids in enumerate(aa_ids)
+        ]
+
+        _, _, batch_tokens = self.plm_batch_converter(sequence_data)
+        batch_tokens = batch_tokens.to(self.device)
+
+        with torch.no_grad():
+            results = self.plm(batch_tokens, repr_layers=[self.plm_repr_layer])
+            token_representations = results["representations"][self.plm_repr_layer]
+
+        sequence_representations = []
+        for i, (_, seq) in enumerate(sequence_data):
+            sequence_representations.append(token_representations[i, 1 : len(seq) + 1])
+        plm_embeddings = torch.stack(sequence_representations, dim=0)
+
+        return plm_embeddings
+
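Two details worth noting in the method above (commentary, not part of the commit): residue indices outside the protein table are rendered as ESM's unknown token 'X', and the `1 : len(seq) + 1` slice drops the BOS token that ESM prepends to every sequence, so each chain keeps exactly one vector per residue. A standalone toy of the index-to-letter fallback, assuming the standard 20-letter AlphaFold residue ordering that `get_residue_constants` returns for proteins:

```python
# assumed residue table for illustration; the real one comes from get_residue_constants
restypes = list("ARNDCQEGHILKMFPSTWYV")

ids = [0, 12, 19, 20, -1]           # 20 and -1 fall outside the table
seq = "".join(restypes[i] if 0 <= i < len(restypes) else "X" for i in ids)
print(seq)                          # 'AMVXX' -> out-of-range ids become 'X' before tokenization
```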
     @typecheck
     def forward_with_alphafold3_inputs(
         self,
@@ -6342,6 +6402,7 @@ def forward(
         distance_labels: Int['b n n'] | Int['b m m'] | None = None,
         resolved_labels: Int['b m'] | None = None,
         resolution: Float[' b'] | None = None,
+        token_constraints: Int['b n n dac'] | None = None,
         return_loss_breakdown = False,
         return_loss: bool = None,
         return_all_diffused_atom_pos: bool = False,
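A sketch of what a caller could pass for the new input, assuming the model was built with `constraint_embeddings = 2` (the meaning of each `dac` channel, e.g. a contact or pocket flag, is left to the data pipeline and is not defined by this commit):

```python
import torch

batch, num_tokens, dac = 1, 64, 2

# hypothetical pairwise constraint features, matching the Int['b n n dac'] shape in the signature
# (cast to float before the linear projection if your setup requires it)
token_constraints = torch.zeros(batch, num_tokens, num_tokens, dac, dtype=torch.long)
token_constraints[:, 3, 17, 0] = 1   # flag one constrained token pair in channel 0
token_constraints[:, 17, 3, 0] = 1   # keep the map symmetric

# then forwarded alongside the usual inputs, e.g. alphafold3(..., token_constraints = token_constraints)
```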
@@ -6488,6 +6549,30 @@ def forward(
 
         single_init = seq_unpack_one(single_init)
 
+        # handle maybe pairwise token constraint embeddings
+
+        if exists(self.constraint_embeddings):
+            assert exists(
+                token_constraints
+            ), "`token_constraints` must be provided to use constraint embeddings."
+
+            pairwise_constraint_embeds = self.constraint_embeds(token_constraints)
+            pairwise_init = pairwise_init + pairwise_constraint_embeds
+
+        # handle maybe protein language model (PLM) embeddings
+
+        if exists(self.plm_embeddings):
+            molecule_aa_ids = torch.where(
+                molecule_ids < 0,
+                NUM_HUMAN_AMINO_ACIDS,
+                molecule_ids.clamp(max=NUM_HUMAN_AMINO_ACIDS),
+            )
+
+            molecule_plm_embeddings = self.extract_plm_embeddings(molecule_aa_ids)
+            single_plm_init = self.plm_embeds(molecule_plm_embeddings)
+
+            single_init = single_init + single_plm_init
+
         # relative positional encoding
 
         relative_position_encoding = self.relative_position_encoding(
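The `torch.where` / `clamp` combination above routes padding and non-protein tokens to the index `NUM_HUMAN_AMINO_ACIDS`, which falls outside the residue table and therefore becomes the 'X' fallback inside `extract_plm_embeddings`. A toy run with an assumed constant value (import the real one from the repository's constants):

```python
import torch

NUM_HUMAN_AMINO_ACIDS = 20   # placeholder for illustration only

molecule_ids = torch.tensor([[3, 7, -1, 25]])   # -1: padding, 25: e.g. a nucleotide or ligand id
aa_ids = torch.where(
    molecule_ids < 0,
    NUM_HUMAN_AMINO_ACIDS,
    molecule_ids.clamp(max=NUM_HUMAN_AMINO_ACIDS),
)
print(aa_ids)   # tensor([[ 3,  7, 20, 20]]) -> both special cases map to the out-of-range 'X' index
```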