
Commit f7cdbef

Add NLMs (#270)
* Update pyproject.toml
* Create nlm.py
* Update data_utils.py
* Update plm.py
* Update inputs.py
* Update alphafold3.py
* Update test_af3.py
1 parent 0bc7541 commit f7cdbef

File tree

7 files changed: +348, -31 lines


alphafold3_pytorch/alphafold3.py

Lines changed: 70 additions & 1 deletion
@@ -58,6 +58,8 @@
     CONSTRAINTS,
     CONSTRAINTS_MASK_VALUE,
     IS_MOLECULE_TYPES,
+    IS_NA_INDICES,
+    IS_NON_NA_INDICES,
     IS_PROTEIN_INDEX,
     IS_DNA_INDEX,
     IS_RNA_INDEX,
@@ -70,6 +72,9 @@
     IS_RNA,
     IS_LIGAND,
     IS_METAL_ION,
+    MAX_DNA_NUCLEOTIDE_ID,
+    MIN_RNA_NUCLEOTIDE_ID,
+    MISSING_RNA_NUCLEOTIDE_ID,
     NUM_HUMAN_AMINO_ACIDS,
     NUM_MOLECULE_IDS,
     NUM_MSA_ONE_HOT,
@@ -85,6 +90,11 @@
     get_residue_constants,
 )

+from alphafold3_pytorch.nlm import (
+    NLMEmbedding,
+    NLMRegistry,
+    remove_nlms
+)
 from alphafold3_pytorch.plm import (
     PLMEmbedding,
     PLMRegistry,
@@ -149,7 +159,8 @@
 dmf - additional msa feats derived from msa (has_deletion and deletion_value)
 dtf - additional token feats derived from msa (profile and deletion_mean)
 dac - additional pairwise token constraint embeddings
-dpe - additional protein language model embeddings from esm
+dpe - additional protein language model embeddings
+dne - additional nucleotide language model embeddings
 t - templates
 s - msa
 r - registers
@@ -5957,7 +5968,9 @@ def __init__(
         detach_when_recycling = True,
         pdb_training_set=True,
         plm_embeddings: PLMEmbedding | tuple[PLMEmbedding, ...] | None = None,
+        nlm_embeddings: NLMEmbedding | tuple[NLMEmbedding, ...] | None = None,
         plm_kwargs: dict | tuple[dict, ...] | None = None,
+        nlm_kwargs: dict | tuple[dict, ...] | None = None,
         constraints: List[CONSTRAINTS] | None = None,
     ):
         super().__init__()
@@ -6033,6 +6046,34 @@ def __init__(

         self.to_plm_embeds = LinearNoBias(concatted_plm_embed_dim, dim_single)

+        # optional nucleotide language model(s) (NLM) embeddings
+
+        self.nlms = None
+
+        if exists(nlm_embeddings):
+            self.nlms = ModuleList([])
+
+            for one_nlm_embedding, one_nlm_kwargs in zip_longest(
+                cast_tuple(nlm_embeddings), cast_tuple(nlm_kwargs)
+            ):
+                assert (
+                    one_nlm_embedding in NLMRegistry
+                ), f"Received invalid NLM embedding name: {one_nlm_embedding}. Acceptable ones are {list(NLMRegistry.keys())}."
+
+                constructor = NLMRegistry.get(one_nlm_embedding)
+
+                one_nlm_kwargs = default(one_nlm_kwargs, {})
+                nlm = constructor(**one_nlm_kwargs)
+
+                freeze_(nlm)
+
+                self.nlms.append(nlm)
+
+        if exists(self.nlms):
+            concatted_nlm_embed_dim = sum([nlm.embed_dim for nlm in self.nlms])
+
+            self.to_nlm_embeds = LinearNoBias(concatted_nlm_embed_dim, dim_single)
+
         # atoms per window

         self.atoms_per_window = atoms_per_window
@@ -6261,10 +6302,12 @@ def device(self):
         return self.zero.device

     @remove_plms
+    @remove_nlms
     def state_dict(self, *args, **kwargs):
         return super().state_dict(*args, **kwargs)

     @remove_plms
+    @remove_nlms
     def load_state_dict(self, *args, **kwargs):
         return super().load_state_dict(*args, **kwargs)

@@ -6590,6 +6633,32 @@ def forward(

         single_init = single_init + single_plm_init

+        # handle maybe nucleotide language model (NLM) embeddings
+
+        if exists(self.nlms):
+            na_ids = torch.where(
+                is_molecule_types[..., IS_NA_INDICES].any(dim=-1)
+                & (
+                    (molecule_ids < MIN_RNA_NUCLEOTIDE_ID) | (molecule_ids > MAX_DNA_NUCLEOTIDE_ID)
+                ),
+                MISSING_RNA_NUCLEOTIDE_ID,
+                molecule_ids,
+            )
+            molecule_na_ids = torch.where(
+                is_molecule_types[..., IS_NON_NA_INDICES].any(dim=-1),
+                -1,
+                na_ids,
+            )
+
+            nlm_embeds = [nlm(molecule_na_ids) for nlm in self.nlms]
+
+            # concat all NLM embeddings and project and add to single init
+
+            all_nlm_embeds = torch.cat(nlm_embeds, dim=-1)
+            single_nlm_init = self.to_nlm_embeds(all_nlm_embeds)
+
+            single_init = single_init + single_nlm_init
+
         # relative positional encoding

         relative_position_encoding = self.relative_position_encoding(

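For reference, a minimal standalone sketch (not part of the diff) of the nucleotide-id masking added to forward above. It uses the constants added in inputs.py below, and assumes the molecule-type columns are ordered [protein, rna, dna, ligand, metal_ion] as implied by those index constants; the toy ids are hypothetical.

    import torch

    from alphafold3_pytorch.inputs import (
        IS_NA_INDICES,
        IS_NON_NA_INDICES,
        MIN_RNA_NUCLEOTIDE_ID,
        MAX_DNA_NUCLEOTIDE_ID,
        MISSING_RNA_NUCLEOTIDE_ID,
    )

    # toy batch: one sequence with a protein residue, an RNA nucleotide, and a DNA nucleotide
    # columns of is_molecule_types assumed to be [protein, rna, dna, ligand, metal_ion]
    is_molecule_types = torch.tensor([[
        [True,  False, False, False, False],  # protein
        [False, True,  False, False, False],  # RNA
        [False, False, True,  False, False],  # DNA
    ]])

    # hypothetical molecule ids: an amino-acid id for the protein token,
    # ids inside the nucleotide window for the RNA/DNA tokens
    molecule_ids = torch.tensor([[3, MIN_RNA_NUCLEOTIDE_ID, MAX_DNA_NUCLEOTIDE_ID]])

    # nucleic-acid tokens whose id falls outside the nucleotide window are remapped
    # to the "missing RNA nucleotide" id; everything else keeps its id
    na_ids = torch.where(
        is_molecule_types[..., IS_NA_INDICES].any(dim=-1)
        & ((molecule_ids < MIN_RNA_NUCLEOTIDE_ID) | (molecule_ids > MAX_DNA_NUCLEOTIDE_ID)),
        MISSING_RNA_NUCLEOTIDE_ID,
        molecule_ids,
    )

    # non-nucleic-acid tokens (protein / ligand / metal ion) are set to -1,
    # which the NLM wrappers later render as their mask token
    molecule_na_ids = torch.where(
        is_molecule_types[..., IS_NON_NA_INDICES].any(dim=-1), -1, na_ids
    )

    print(molecule_na_ids)  # protein position -> -1, nucleotide positions keep their ids
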
alphafold3_pytorch/inputs.py

Lines changed: 8 additions & 0 deletions
@@ -123,8 +123,11 @@
 IS_DNA_INDEX = 2
 IS_LIGAND_INDEX = -2
 IS_METAL_ION_INDEX = -1
+
 IS_BIOMOLECULE_INDICES = slice(0, 3)
 IS_NON_PROTEIN_INDICES = slice(1, 5)
+IS_NA_INDICES = slice(1, 3)
+IS_NON_NA_INDICES = [0, 3, 4]

 IS_PROTEIN, IS_RNA, IS_DNA, IS_LIGAND, IS_METAL_ION = tuple(
     (IS_MOLECULE_TYPES + i if i < 0 else i)
@@ -144,6 +147,11 @@
 NUM_HUMAN_AMINO_ACIDS = len(HUMAN_AMINO_ACIDS) - 1  # exclude unknown amino acid type
 NUM_MSA_ONE_HOT = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) + 1

+MIN_RNA_NUCLEOTIDE_ID = len(HUMAN_AMINO_ACIDS)
+MAX_DNA_NUCLEOTIDE_ID = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) - 1
+
+MISSING_RNA_NUCLEOTIDE_ID = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) - 1
+
 DEFAULT_NUM_MOLECULE_MODS = 4  # `mod_protein`, `mod_rna`, `mod_dna`, and `mod_unk`
 ADDITIONAL_MOLECULE_FEATS = 5

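The new constants simply mark out the nucleotide window of the shared molecule-id vocabulary, which is laid out as [amino acids | RNA nucleotides | DNA nucleotides | ...]. A small sketch that restates the definitions above, assuming the package is importable:

    from alphafold3_pytorch.inputs import (
        HUMAN_AMINO_ACIDS,
        RNA_NUCLEOTIDES,
        DNA_NUCLEOTIDES,
        MIN_RNA_NUCLEOTIDE_ID,
        MAX_DNA_NUCLEOTIDE_ID,
        MISSING_RNA_NUCLEOTIDE_ID,
    )

    # the nucleotide window starts right after the amino-acid ids ...
    assert MIN_RNA_NUCLEOTIDE_ID == len(HUMAN_AMINO_ACIDS)

    # ... and ends with the last DNA id
    assert MAX_DNA_NUCLEOTIDE_ID == MIN_RNA_NUCLEOTIDE_ID + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) - 1

    # the last id in the RNA block is reused as the "missing" RNA nucleotide id
    assert MISSING_RNA_NUCLEOTIDE_ID == MIN_RNA_NUCLEOTIDE_ID + len(RNA_NUCLEOTIDES) - 1
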
alphafold3_pytorch/nlm.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+from functools import wraps
+
+import torch
+from beartype.typing import Literal
+from torch import tensor
+from torch.nn import Module
+
+from alphafold3_pytorch.common.biomolecule import get_residue_constants
+from alphafold3_pytorch.inputs import IS_DNA, IS_RNA
+from alphafold3_pytorch.tensor_typing import Float, Int, typecheck
+from alphafold3_pytorch.utils.data_utils import join
+
+# functions
+
+def remove_nlms(fn):
+    """Decorator to remove NLMs from the model before calling the inner function and then restore
+    them afterwards."""
+
+    @wraps(fn)
+    def inner(self, *args, **kwargs):
+        has_nlms = hasattr(self, "nlms")
+        if has_nlms:
+            nlms = self.nlms
+            delattr(self, "nlms")
+
+        out = fn(self, *args, **kwargs)
+
+        if has_nlms:
+            self.nlms = nlms
+
+        return out
+
+    return inner
+
+
+# constants
+
+rna_constants = get_residue_constants(res_chem_index=IS_RNA)
+dna_constants = get_residue_constants(res_chem_index=IS_DNA)
+
+rna_restypes = rna_constants.restypes + ["X"]
+dna_restypes = dna_constants.restypes + ["X"]
+
+rna_min_restype_num = rna_constants.min_restype_num
+dna_min_restype_num = dna_constants.min_restype_num
+
+RINALMO_MASK_TOKEN = "-"  # nosec
+
+# class
+
+
+class RiNALMoWrapper(Module):
+    """A wrapper for the RiNALMo model to provide NLM embeddings."""
+
+    def __init__(self):
+        super().__init__()
+        from multimolecule import RiNALMoModel, RnaTokenizer
+
+        self.register_buffer("dummy", tensor(0), persistent=False)
+
+        self.tokenizer = RnaTokenizer.from_pretrained(
+            "multimolecule/rinalmo", replace_T_with_U=False
+        )
+        self.model = RiNALMoModel.from_pretrained("multimolecule/rinalmo")
+
+        self.embed_dim = 1280
+
+    @torch.no_grad()
+    @typecheck
+    def forward(
+        self, na_ids: Int["b n"]  # type: ignore
+    ) -> Float["b n dne"]:  # type: ignore
+        """Get NLM embeddings for a batch of (pseudo-)nucleotide sequences.
+
+        :param na_ids: A batch of nucleotide residue indices.
+        :return: The NLM embeddings for the input sequences.
+        """
+        device, seq_len = self.dummy.device, na_ids.shape[-1]
+
+        sequence_data = [
+            join(
+                [
+                    (
+                        RINALMO_MASK_TOKEN
+                        if i == -1
+                        else (
+                            rna_restypes[i - rna_min_restype_num]
+                            if rna_min_restype_num <= i < dna_min_restype_num
+                            else dna_restypes[i - dna_min_restype_num]
+                        )
+                    )
+                    for i in ids
+                ]
+            )
+            for ids in na_ids
+        ]

+        # encode to ids
+
+        inputs = self.tokenizer(sequence_data, return_tensors="pt").to(device)
+
+        # forward through nlm
+
+        embeddings = self.model(inputs.input_ids, attention_mask=inputs.attention_mask)
+
+        # remove prefix
+
+        nlm_embeddings = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]
+
+        return nlm_embeddings
+
+
+# NLM embedding type and registry
+
+NLMRegistry = dict(rinalmo=RiNALMoWrapper)
+
+NLMEmbedding = Literal["rinalmo"]

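A minimal usage sketch for the wrapper above (not part of the diff). It assumes the `multimolecule` package is installed so the RiNALMo weights and tokenizer can be downloaded, and that the nucleotide ids line up with the constants added to inputs.py in this commit; the toy id values are purely illustrative.

    import torch

    from alphafold3_pytorch.inputs import MIN_RNA_NUCLEOTIDE_ID
    from alphafold3_pytorch.nlm import NLMRegistry

    nlm = NLMRegistry["rinalmo"]()  # the same constructor Alphafold3 looks up internally

    # a toy batch of pseudo-nucleotide ids: three RNA nucleotides flanked by -1,
    # which marks non-nucleic-acid positions and is rendered as the RiNALMo mask token
    na_ids = torch.tensor([[
        -1,
        MIN_RNA_NUCLEOTIDE_ID,
        MIN_RNA_NUCLEOTIDE_ID + 1,
        MIN_RNA_NUCLEOTIDE_ID + 2,
        -1,
    ]])

    embeds = nlm(na_ids)
    print(embeds.shape)  # expected (1, 5, 1280): one RiNALMo embedding per token

Inside Alphafold3 itself, the same wrapper is selected by passing nlm_embeddings = "rinalmo" (optionally with nlm_kwargs) to the constructor, per the __init__ changes above.
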
alphafold3_pytorch/plm.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,39 @@
99
from alphafold3_pytorch.common.biomolecule import get_residue_constants
1010
from alphafold3_pytorch.inputs import IS_PROTEIN
1111
from alphafold3_pytorch.tensor_typing import Float, Int, typecheck
12+
from alphafold3_pytorch.utils.data_utils import join
1213

1314
# functions
1415

15-
def join(arr, delimiter = ''): # just redo an ugly part of python
16-
return delimiter.join(arr)
1716

1817
def remove_plms(fn):
18+
"""Decorator to remove PLMs from the model before calling the inner function and then restore
19+
them afterwards."""
20+
1921
@wraps(fn)
2022
def inner(self, *args, **kwargs):
21-
has_plms = hasattr(self, 'plms')
23+
has_plms = hasattr(self, "plms")
2224
if has_plms:
2325
plms = self.plms
24-
delattr(self, 'plms')
26+
delattr(self, "plms")
2527

2628
out = fn(self, *args, **kwargs)
2729

2830
if has_plms:
2931
self.plms = plms
3032

3133
return out
34+
3235
return inner
3336

37+
3438
# constants
3539

3640
aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
3741
restypes = aa_constants.restypes + ["X"]
3842

39-
ESM_MASK_TOKEN = "-"
40-
PROST_T5_MASK_TOKEN = "X"
43+
ESM_MASK_TOKEN = "-" # nosec
44+
PROST_T5_MASK_TOKEN = "X" # nosec
4145

4246
# class
4347

@@ -70,7 +74,9 @@ def forward(
7074
:param aa_ids: A batch of amino acid residue indices.
7175
:return: The PLM embeddings for the input sequences.
7276
"""
73-
device, repr_layer = self.dummy.device, self.repr_layer
77+
device, seq_len, repr_layer = self.dummy.device, aa_ids.shape[-1], self.repr_layer
78+
79+
# following the readme at https://github.com/facebookresearch/esm
7480

7581
sequence_data = [
7682
(
@@ -80,18 +86,21 @@ def forward(
8086
for mol_idx, ids in enumerate(aa_ids)
8187
]
8288

89+
# encode to IDs
90+
8391
_, _, batch_tokens = self.batch_converter(sequence_data)
8492
batch_tokens = batch_tokens.to(device)
8593

94+
# forward through plm
95+
8696
self.model.eval()
8797
results = self.model(batch_tokens, repr_layers=[repr_layer])
8898

89-
token_representations = results["representations"][repr_layer]
99+
embeddings = results["representations"][repr_layer]
100+
101+
# remove prefix
90102

91-
sequence_representations = []
92-
for i, (_, seq) in enumerate(sequence_data):
93-
sequence_representations.append(token_representations[i, 1 : len(seq) + 1])
94-
plm_embeddings = torch.stack(sequence_representations, dim=0)
103+
plm_embeddings = embeddings[:, 1 : (seq_len + 1)]
95104

96105
return plm_embeddings
97106

@@ -121,20 +130,21 @@ def forward(
121130
"""
122131
device, seq_len = self.dummy.device, aa_ids.shape[-1]
123132

124-
str_sequences = [
125-
join([(PROST_T5_MASK_TOKEN if i == -1 else restypes[i]) for i in ids]) for ids in aa_ids
126-
]
127-
128133
# following the readme at https://github.com/mheinzinger/ProstT5
129134

130-
str_sequences = [
131-
join(list(re.sub(r"[UZOB]", "X", str_seq)), " ") for str_seq in str_sequences
135+
sequence_data = [
136+
join([(PROST_T5_MASK_TOKEN if i == -1 else restypes[i]) for i in ids])
137+
for ids in aa_ids
138+
]
139+
140+
sequence_data = [
141+
join(list(re.sub(r"[UZOB]", "X", str_seq)), " ") for str_seq in sequence_data
132142
]
133143

134144
# encode to ids
135145

136146
inputs = self.tokenizer.batch_encode_plus(
137-
str_sequences, add_special_tokens=True, padding="longest", return_tensors="pt"
147+
sequence_data, add_special_tokens=True, padding="longest", return_tensors="pt"
138148
).to(device)
139149

140150
# forward through plm
@@ -143,8 +153,8 @@ def forward(
143153

144154
# remove prefix
145155

146-
plm_embedding = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]
147-
return plm_embedding
156+
plm_embeddings = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]
157+
return plm_embeddings
148158

149159

150160
# PLM embedding type and registry

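The ESM change above replaces the per-sequence loop with a single batched slice. A toy sketch of that prefix removal, using random stand-in tensors only (none of the names below come from the esm library): the language model prepends one special token (BOS/CLS) per sequence, so dropping column 0 and keeping the next seq_len positions recovers the per-residue embeddings for the whole batch at once.

    import torch

    batch, seq_len, dim = 2, 4, 8

    # stand-in for the PLM output: a leading special token, then seq_len residue
    # tokens, then trailing special/padding tokens
    token_representations = torch.randn(batch, seq_len + 2, dim)

    # drop the leading special token and keep the next seq_len positions, batched
    plm_embeddings = token_representations[:, 1 : (seq_len + 1)]

    assert plm_embeddings.shape == (batch, seq_len, dim)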