
Commit 477ad19

Update plm.py (#257)
1 parent 1f3e701 commit 477ad19

File tree

1 file changed: +44 -50 lines

alphafold3_pytorch/plm.py

Lines changed: 44 additions & 50 deletions
@@ -2,24 +2,13 @@
 from functools import partial
 
 import torch
+from beartype.typing import Literal
 from torch import tensor
 from torch.nn import Module
 
-from beartype.typing import Literal
-
-from alphafold3_pytorch.tensor_typing import (
-    typecheck,
-    Float,
-    Int
-)
-
-from alphafold3_pytorch.common.biomolecule import (
-    get_residue_constants,
-)
-
-from alphafold3_pytorch.inputs import (
-    IS_PROTEIN,
-)
+from alphafold3_pytorch.common.biomolecule import get_residue_constants
+from alphafold3_pytorch.inputs import IS_PROTEIN
+from alphafold3_pytorch.tensor_typing import Float, Int, typecheck
 
 # functions
 
@@ -28,41 +17,48 @@ def join(arr, delimiter = ''): # just redo an ugly part of python
 
 # constants
 
-aa_constants = get_residue_constants(res_chem_index = IS_PROTEIN)
-restypes_index = dict(enumerate(aa_constants.restypes))
+aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
+restypes = aa_constants.restypes + ["X"]
 
 # class
 
+
 class ESMWrapper(Module):
+    """A wrapper for the ESM model to provide PLM embeddings."""
+
     def __init__(
         self,
-        esm_name,
-        repr_layer = 33
+        esm_name: str,
+        repr_layer: int = 33,
     ):
         super().__init__()
         import esm
+
         self.repr_layer = repr_layer
         self.model, alphabet = esm.pretrained.load_model_and_alphabet_hub(esm_name)
         self.batch_converter = alphabet.get_batch_converter()
 
         self.embed_dim = self.model.embed_dim
-        self.register_buffer('dummy', tensor(0), persistent = False)
+        self.register_buffer("dummy", tensor(0), persistent=False)
 
     @torch.no_grad()
     @typecheck
     def forward(
-        self,
-        aa_ids: Int['b n']
-    ) -> Float['b n dpe']:
+        self, aa_ids: Int["b n"]  # type: ignore
+    ) -> Float["b n dpe"]:  # type: ignore
+        """Get PLM embeddings for a batch of (pseudo-)protein sequences.
 
+        :param aa_ids: A batch of amino acid residue indices.
+        :return: The PLM embeddings for the input sequences.
+        """
         device, repr_layer = self.dummy.device, self.repr_layer
 
         sequence_data = [
             (
-                f"molecule{i}",
-                join([restypes_index.get(i, 'X') for i in ids]),
+                f"molecule{mol_idx}",
+                join([(restypes[i] if 0 <= i < len(restypes) else "X") for i in ids]),
             )
-            for i, ids in enumerate(aa_ids)
+            for mol_idx, ids in enumerate(aa_ids)
         ]
 
         _, _, batch_tokens = self.batch_converter(sequence_data)
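
The constants change above replaces the dict-based restypes_index lookup with an explicit "X" entry plus a bounds check in forward(). A minimal sketch of the two behaviours, using a hypothetical 20-letter stand-in for aa_constants.restypes:

# sketch only: restypes_stub is a hypothetical stand-in for aa_constants.restypes
restypes_stub = list("ARNDCQEGHILKMFPSTWYV")

# old mapping: index -> letter, unknown indices fall back to 'X' via dict.get
restypes_index = dict(enumerate(restypes_stub))

def old_lookup(i):
    return restypes_index.get(i, 'X')

# new mapping: explicit "X" appended at index 20, plus a bounds check
restypes = restypes_stub + ["X"]

def new_lookup(i):
    return restypes[i] if 0 <= i < len(restypes) else "X"

assert old_lookup(0) == new_lookup(0) == "A"
assert old_lookup(20) == new_lookup(20) == "X"  # unknown-residue index resolves to X either way
assert old_lookup(99) == new_lookup(99) == "X"  # out-of-range indices also fall back to X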
@@ -80,64 +76,62 @@ def forward(
 
         return plm_embeddings
 
+
 class ProstT5Wrapper(Module):
+    """A wrapper for the ProstT5 model to provide PLM embeddings."""
+
     def __init__(self):
         super().__init__()
-        from transformers import T5Tokenizer, T5EncoderModel
+        from transformers import T5EncoderModel, T5Tokenizer
 
-        self.register_buffer('dummy', tensor(0), persistent = False)
+        self.register_buffer("dummy", tensor(0), persistent=False)
 
-        self.tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case = False)
+        self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False)
         self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
         self.embed_dim = 1024
 
     @torch.no_grad()
     @typecheck
     def forward(
-        self,
-        aa_ids: Int['b n']
-    ) -> Float['b n dpe']:
+        self, aa_ids: Int["b n"]  # type: ignore
+    ) -> Float["b n dpe"]:  # type: ignore
+        """Get PLM embeddings for a batch of (pseudo-)protein sequences.
 
+        :param aa_ids: A batch of amino acid residue indices.
+        :return: The PLM embeddings for the input sequences.
+        """
         device, seq_len = self.dummy.device, aa_ids.shape[-1]
 
         str_sequences = [
-            join([restypes_index.get(i, 'X') for i in ids])
-            for i, ids in enumerate(aa_ids)
+            join([(restypes[i] if 0 <= i < len(restypes) else "X") for i in ids]) for ids in aa_ids
         ]
 
         # following the readme at https://github.com/mheinzinger/ProstT5
 
-        str_sequences = [join(list(re.sub(r"[UZOB]", "X", str_seq)), ' ') for str_seq in str_sequences]
+        str_sequences = [
+            join(list(re.sub(r"[UZOB]", "X", str_seq)), " ") for str_seq in str_sequences
+        ]
 
         # encode to ids
 
         inputs = self.tokenizer.batch_encode_plus(
-            str_sequences,
-            add_special_tokens = True,
-            padding = "longest",
-            return_tensors = 'pt'
+            str_sequences, add_special_tokens=True, padding="longest", return_tensors="pt"
         ).to(device)
 
         # forward through plm
 
-        embeddings = self.model(
-            inputs.input_ids,
-            attention_mask = inputs.attention_mask
-        )
+        embeddings = self.model(inputs.input_ids, attention_mask=inputs.attention_mask)
 
         # remove prefix
 
-        plm_embedding = embeddings.last_hidden_state[:, 1:(seq_len + 1)]
+        plm_embedding = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]
         return plm_embedding
 
+
 # PLM embedding type and registry
 
 PLMRegistry = dict(
-    esm2_t33_650M_UR50D = partial(ESMWrapper, 'esm2_t33_650M_UR50D'),
-    prostT5 = ProstT5Wrapper
+    esm2_t33_650M_UR50D=partial(ESMWrapper, "esm2_t33_650M_UR50D"), prostT5=ProstT5Wrapper
 )
 
-PLMEmbedding = Literal[
-    "esm2_t33_650M_UR50D",
-    "prostT5"
-]
+PLMEmbedding = Literal["esm2_t33_650M_UR50D", "prostT5"]
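
For context, a minimal usage sketch of the registry and Literal type defined in this file; the weight download, batch shape, and variable names here are illustrative assumptions rather than part of the commit:

import torch

from alphafold3_pytorch.plm import PLMEmbedding, PLMRegistry

plm_name: PLMEmbedding = "esm2_t33_650M_UR50D"

# the registry stores constructors, so calling the entry builds the wrapper
# (this downloads the pretrained ESM weights on first use)
plm = PLMRegistry[plm_name]()

# a hypothetical batch of 2 pseudo-sequences of length 16, with residue
# indices in the same 0..20 range used by the restypes lookup above
aa_ids = torch.randint(0, 21, (2, 16))

embeddings = plm(aa_ids)  # expected shape: (2, 16, plm.embed_dim)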
