Skip to content

Commit c6b8c34

Browse files
authored
ported over updates from internal (#208)
Co-authored-by: chetan <[email protected]>
1 parent 82b1431 commit c6b8c34

File tree

13 files changed

+2392
-53
lines changed

13 files changed

+2392
-53
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
esm.egg-info
1+
esm.egg-info
2+
# pixi environments
3+
.pixi
4+
*.egg-info

cookbook/snippets/esmc.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,35 @@
1+
import os
2+
13
from esm.models.esmc import ESMC
4+
from esm.sdk import client
25
from esm.sdk.api import (
36
ESMCInferenceClient,
47
ESMProtein,
8+
ESMProteinTensor,
59
LogitsConfig,
610
LogitsOutput,
711
)
12+
from esm.sdk.forge import ESM3ForgeInferenceClient
813

914

10-
def main(client: ESMCInferenceClient):
15+
def main(client: ESMCInferenceClient | ESM3ForgeInferenceClient):
1116
# ================================================================
1217
# Example usage: one single protein
1318
# ================================================================
1419
protein = ESMProtein(sequence="AAAAA")
1520

1621
# Use logits endpoint. Using bf16 for inference optimization
1722
protein_tensor = client.encode(protein)
23+
assert isinstance(
24+
protein_tensor, ESMProteinTensor
25+
), f"Expected ESMProteinTensor but got error: {protein_tensor}"
1826
output = client.logits(
1927
protein_tensor,
2028
LogitsConfig(sequence=True, return_embeddings=True, return_hidden_states=True),
2129
)
2230
assert isinstance(
2331
output, LogitsOutput
24-
), f"LogitsOutput was expected but got {output}"
32+
), f"LogitsOutput was expected but got error: {output}"
2533
assert output.logits is not None and output.logits.sequence is not None
2634
assert output.embeddings is not None
2735
assert output.hidden_states is not None
@@ -30,9 +38,15 @@ def main(client: ESMCInferenceClient):
3038
)
3139

3240
# request a specific hidden layer.
41+
assert isinstance(
42+
protein_tensor, ESMProteinTensor
43+
), f"Expected ESMProteinTensor but got error: {protein_tensor}"
3344
output = client.logits(
3445
protein_tensor, LogitsConfig(return_hidden_states=True, ith_hidden_layer=1)
3546
)
47+
assert isinstance(
48+
output, LogitsOutput
49+
), f"LogitsOutput was expected but got error: {output}"
3650
assert output.hidden_states is not None
3751
print(f"Client returned hidden states with shape {output.hidden_states.shape}")
3852

@@ -57,6 +71,15 @@ def raw_forward(model: ESMC):
5771

5872

5973
if __name__ == "__main__":
60-
model = ESMC.from_pretrained("esmc_300m")
61-
main(model)
62-
raw_forward(model)
74+
if os.environ.get("ESM_API_KEY", ""):
75+
print("ESM_API_KEY found. Trying to use model from Forge...")
76+
main(client(model="esmc-300m-2024-12"))
77+
else:
78+
print("No ESM_API_KEY found. Trying to load model locally...")
79+
print(
80+
"TO try this script with a Forge API, please run ESM_API_KEY=your_api_key python esm3.py"
81+
)
82+
main(ESMC.from_pretrained("esm3_sm_open_v1"))
83+
model = ESMC.from_pretrained("esmc_300m")
84+
main(model)
85+
raw_forward(model)

cookbook/tutorials/3_gfp_design.ipynb

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
},
5151
{
5252
"cell_type": "code",
53-
"execution_count": null,
53+
"execution_count": 13,
5454
"metadata": {
5555
"id": "poK5NTaXRGcX"
5656
},
@@ -85,7 +85,7 @@
8585
},
8686
{
8787
"cell_type": "code",
88-
"execution_count": null,
88+
"execution_count": 14,
8989
"metadata": {
9090
"id": "zNrU9Q2SYonX"
9191
},
@@ -105,7 +105,7 @@
105105
},
106106
{
107107
"cell_type": "code",
108-
"execution_count": null,
108+
"execution_count": 26,
109109
"metadata": {
110110
"id": "Tna_mjGOjdXA"
111111
},
@@ -367,7 +367,8 @@
367367
"source": [
368368
"%%time\n",
369369
"\n",
370-
"num_tokens_to_decode = (prompt.sequence == 32).sum().item()\n",
370+
"# Based on internal, there's not a benefit to iterative decoding past 20 steps\n",
371+
"num_tokens_to_decode = min((prompt.sequence == 32).sum().item(), 20)\n",
371372
"\n",
372373
"sequence_generation = model.generate(\n",
373374
" # Generate a sequence.\n",
@@ -380,7 +381,7 @@
380381
"length_of_sequence = sequence_generation.sequence.numel() - 2\n",
381382
"sequence_generation = model.generate(\n",
382383
" sequence_generation,\n",
383-
" GenerationConfig(track=\"structure\", num_steps=length_of_sequence, temperature=0.0),\n",
384+
" GenerationConfig(track=\"structure\", num_steps=1, temperature=0.0),\n",
384385
")\n",
385386
"\n",
386387
"# Decode to AA string and coordinates.\n",
@@ -528,11 +529,21 @@
528529
"provenance": []
529530
},
530531
"kernelspec": {
531-
"display_name": "Python 3",
532+
"display_name": "default",
533+
"language": "python",
532534
"name": "python3"
533535
},
534536
"language_info": {
535-
"name": "python"
537+
"codemirror_mode": {
538+
"name": "ipython",
539+
"version": 3
540+
},
541+
"file_extension": ".py",
542+
"mimetype": "text/x-python",
543+
"name": "python",
544+
"nbconvert_exporter": "python",
545+
"pygments_lexer": "ipython3",
546+
"version": "3.10.0"
536547
}
537548
},
538549
"nbformat": 4,

esm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = "3.1.4"
1+
__version__ = "3.1.5"
22

esm/sdk/api.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,8 @@ class ESMProteinError(Exception, ProteinType):
266266
@define
267267
class GenerationConfig:
268268
track: str = ""
269-
invalid_ids: Sequence[int] = []
269+
# By default avoid sampling the amino acid "X"
270+
invalid_ids: Sequence[int] = [24]
270271
# Controls the number of tokens to unmask during each round of iterative generation.
271272
schedule: str = attr.field(
272273
validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
@@ -275,11 +276,11 @@ class GenerationConfig:
275276
# "random" will unmask a correct number of tokens randomly.
276277
# "entropy" will unmask the tokens with the lowest logit entropy first.
277278
strategy: str = attr.field(
278-
validator=attr.validators.in_(["random", "entropy"]), default="entropy"
279+
validator=attr.validators.in_(["random", "entropy"]), default="random"
279280
)
280-
# Set this to a higher value for better generation results.
281+
# Setting default to 20, as there is diminishing return for decoding steps more than 20.
281282
# Note that this needs to be less than or equal to the sequence length.
282-
num_steps: int = 1
283+
num_steps: int = 20
283284
temperature: float = 1.0
284285
temperature_annealing: bool = False
285286
top_p: float = 1.0

esm/sdk/forge.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def log_retry_attempt(retry_state):
6060

6161

6262
def _validate_protein_tensor_input(input):
63+
if isinstance(input, ESMProteinError):
64+
raise ValueError(
65+
f"Input must be an ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
66+
)
6367
if not isinstance(input, ESMProteinTensor):
6468
raise ValueError(
6569
f"Input must be an ESMProteinTensor instance, but received {type(input)} instead. "
@@ -71,14 +75,25 @@ class SequenceStructureForgeInferenceClient:
7175
def __init__(
7276
self,
7377
url: str = "https://forge.evolutionaryscale.ai",
78+
model: str | None = None,
7479
token: str = "",
7580
request_timeout: int | None = None,
7681
):
82+
"""
83+
Forge client for folding and inverse folding between sequence and structure spaces.
84+
85+
Args:
86+
url: URL of the Forge server.
87+
model: Name of the model to be used for folding / inv folding.
88+
token: API token.
89+
request_timeout: Override the system default request timeout, in seconds.
90+
"""
7791
if token == "":
7892
raise RuntimeError(
7993
"Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
8094
)
8195
self.url = url
96+
self.model = model
8297
self.token = token
8398
self.headers = {"Authorization": f"Bearer {self.token}"}
8499
self.request_timeout = request_timeout
@@ -89,9 +104,19 @@ def fold(
89104
potential_sequence_of_concern: bool,
90105
model_name: str | None = None,
91106
) -> ESMProtein | ESMProteinError:
107+
"""Predict coordinates for a protein sequence.
108+
109+
Args:
110+
sequence: Protein sequence to be folded.
111+
potential_sequence_of_concern: Self disclosed potential_of_concern bit.
112+
This bit is largely ignored by the fold() endpoint.
113+
model_name: Override the client level model name if needed.
114+
"""
92115
request = {"sequence": sequence}
93116
if model_name is not None:
94117
request["model"] = model_name
118+
elif self.model is not None:
119+
request["model"] = self.model
95120
try:
96121
data = self._post("fold", request, potential_sequence_of_concern)
97122
except ESMProteinError as e:
@@ -109,6 +134,17 @@ def inverse_fold(
109134
potential_sequence_of_concern: bool,
110135
model_name: str | None = None,
111136
) -> ESMProtein | ESMProteinError:
137+
"""Generate protein sequence from its structure.
138+
139+
This endpoint is only supported by generative models like ESM3.
140+
141+
Args:
142+
coordinates: Protein sequence coordinates to be inversely folded.
143+
config: Configurations related to inverse folding generation.
144+
potential_sequence_of_concern: Self disclosed potential_of_concern bit.
145+
Requires special permission to use.
146+
model_name: Override the client level model name if needed.
147+
"""
112148
inverse_folding_config = {
113149
"invalid_ids": config.invalid_ids,
114150
"temperature": config.temperature,
@@ -119,6 +155,8 @@ def inverse_fold(
119155
}
120156
if model_name is not None:
121157
request["model"] = model_name
158+
elif self.model is not None:
159+
request["model"] = self.model
122160
try:
123161
data = self._post("inverse_fold", request, potential_sequence_of_concern)
124162
except ESMProteinError as e:
@@ -208,6 +246,16 @@ def wrapper(instance, *args, **kwargs):
208246

209247
@retry_decorator
210248
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
249+
if isinstance(input, ESMProteinError):
250+
raise ValueError(
251+
f"Input must be an ESMProtein or ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
252+
)
253+
assert isinstance(input, ESMProtein) or isinstance(input, ESMProteinTensor)
254+
if input.sequence is not None and config.num_steps > len(input.sequence):
255+
config.num_steps = len(input.sequence)
256+
print(
257+
"Warning: num_steps cannot exceed sequence length. Setting num_steps to sequence length."
258+
)
211259
if isinstance(input, ESMProtein):
212260
output = self.__generate_protein(input, config)
213261
elif isinstance(input, ESMProteinTensor):

esm/sdk/sagemaker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010

1111
class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
12-
def __init__(self, endpoint_name: str):
12+
def __init__(self, endpoint_name: str, model: str | None = None):
1313
"""SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.
1414
1515
Args:
1616
endpoint_name: Name of the SageMaker endpoint.
1717
"""
1818
# Dummy URL and token to make SequenceStructureForgeInferenceClient happy.
19-
super().__init__(url="", token="dummy")
19+
super().__init__(url="", model=model, token="dummy")
2020

2121
self._endpoint_name = endpoint_name
2222

esm/tokenization/sequence_tokenizer.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,45 @@ def bos_token(self):
7272
def bos_token_id(self):
7373
return self.cls_token_id
7474

75+
@property
76+
def cls_token(self):
77+
return self._get_token("cls_token")
78+
79+
@property
80+
def cls_token_id(self):
81+
return self._get_token_id(self.cls_token)
82+
83+
@property
84+
def eos_token(self):
85+
return self._get_token("eos_token")
86+
87+
@property
88+
def eos_token_id(self):
89+
return self._get_token_id(self.eos_token)
90+
91+
@property
92+
def mask_token(self):
93+
return self._get_token("mask_token")
94+
95+
@property
96+
def mask_token_id(self):
97+
return self._get_token_id(self.mask_token)
98+
99+
@property
100+
def pad_token(self):
101+
return self._get_token("pad_token")
102+
103+
@property
104+
def pad_token_id(self):
105+
return self._get_token_id(self.pad_token)
106+
75107
@property
76108
def chain_break_token(self):
77109
return self.cb_token
78110

79111
@property
80112
def chain_break_token_id(self):
81-
return self.convert_tokens_to_ids(self.chain_break_token)
113+
return self._get_token_id(self.chain_break_token)
82114

83115
@property
84116
def all_token_ids(self):
@@ -87,3 +119,16 @@ def all_token_ids(self):
87119
@property
88120
def special_token_ids(self):
89121
return self.all_special_ids
122+
123+
def _get_token_id(self, token) -> int:
124+
token_id = self.convert_tokens_to_ids(token)
125+
assert isinstance(token_id, int)
126+
return token_id
127+
128+
def _get_token(self, token_name: str) -> str:
129+
# NOTE: Tokenizers library overloads __getattr__ to expose special tokens
130+
# Adding a helper method around it keeps the base class functionality without overriding
131+
# the property. See: https://github.com/huggingface/transformers/blob/41925e42135257361b7f02aa20e3bbdab3f7b923/src/transformers/tokenization_utils_base.py#L1086
132+
token_str = self.__getattr__(token_name)
133+
assert isinstance(token_str, str)
134+
return token_str

0 commit comments

Comments (0)