Commit f14dfc0

Add ESM

1 parent 4b2a0df commit f14dfc0

17 files changed: +3378 −0 lines changed

esm/README.md

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# ESM-2

This repository provides an implementation of Meta's ESM-2 protein language model in MLX.[^1] ESM-2 is Meta’s second-generation Evolutionary Scale Model, a transformer-based protein language model trained on millions of diverse protein sequences with a masked language modeling objective.

![Example contact prediction map](assets/contact_prediction.png)

_Example contact prediction map for a universal stress protein. In this case, ESM-2 650M achieves 86.4% precision at long-range contacts._
## Setup

Install the requirements:

```bash
pip install -r requirements.txt
```
## Usage

Below are the available ESM-2 models:

| Model | Parameters | Layers |
|-------|------------|--------|
| [`esm2_t6_8M_UR50D`](https://huggingface.co/facebook/esm2_t6_8M_UR50D) | 8M | 6 |
| [`esm2_t12_35M_UR50D`](https://huggingface.co/facebook/esm2_t12_35M_UR50D) | 35M | 12 |
| [`esm2_t30_150M_UR50D`](https://huggingface.co/facebook/esm2_t30_150M_UR50D) | 150M | 30 |
| [`esm2_t33_650M_UR50D`](https://huggingface.co/facebook/esm2_t33_650M_UR50D) | 650M | 33 |
| [`esm2_t36_3B_UR50D`](https://huggingface.co/facebook/esm2_t36_3B_UR50D) | 3B | 36 |
| [`esm2_t48_15B_UR50D`](https://huggingface.co/facebook/esm2_t48_15B_UR50D) | 15B | 48 |
Convert a model to MLX format:

```bash
python convert.py --hf-path facebook/esm2_t33_650M_UR50D
```

This will save the converted model in the `checkpoints/` directory.
### Basic Inference

```python
from esm import ESM2

# Load model and tokenizer
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")

# Example protein sequence (human insulin)
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"

# Tokenize and run inference
tokens = tokenizer.encode(sequence)
result = model(tokens)
logits = result["logits"]  # Shape: (batch, length, vocab_size)
```
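The logits can be normalized into per-residue token probabilities, e.g. to inspect how confident the model is at each position. A minimal sketch with NumPy standing in for the model output (the toy `logits` array and its shape are illustrative only):

```python
import numpy as np

# Toy logits standing in for result["logits"]: (batch, length, vocab_size)
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 110, 33))

# Softmax over the vocabulary axis -> per-residue token probabilities
shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

pred_ids = probs.argmax(axis=-1)   # most likely token id at each position
confidence = probs.max(axis=-1)    # probability of that token
```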
### Masked Language Modeling

```bash
# For a complete example, see main.py
python main.py --sequence "YOUR_SEQUENCE" --mask-position 50
```
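Under the hood, masked prediction replaces one position with the mask token, runs the model, and reads off the token distribution at that position. The scoring step can be sketched as follows, with a toy logits array standing in for real model output (`mask_pos` and the shapes are illustrative, not this repo's API):

```python
import numpy as np

mask_pos = 50
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 112, 33))  # toy (batch, length, vocab_size) output

# Log-softmax at the masked position scores every candidate token
row = logits[0, mask_pos]
log_probs = row - row.max() - np.log(np.exp(row - row.max()).sum())

# Candidate token ids ranked from most to least likely
ranked = np.argsort(log_probs)[::-1]
```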
### Embeddings

```python
# Get sequence-level representations
seq_repr = model.get_sequence_representations(tokens, layer=-1)  # Shape: (batch, embed_dim)

# Extract per-residue representations from specific layers
representations = model.extract_features(tokens, repr_layers=[20, 30, 33])
final_layer = representations[33]  # Shape: (batch, length, embed_dim)
```
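Sequence-level embeddings are typically compared with cosine similarity. A small helper, shown here on random vectors in place of real `seq_repr` outputs (the 1280 dimension matches ESM-2 650M):

```python
import numpy as np

def cosine_similarity(a, b):
    # Cosine of the angle between two embedding vectors, in [-1, 1]
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
emb_a = rng.normal(size=1280)  # stand-ins for pooled sequence embeddings
emb_b = rng.normal(size=1280)

score = cosine_similarity(emb_a, emb_b)
```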
### Contact Prediction

```python
# Predict residue-residue contacts
contacts = model.predict_contacts(tokens)  # Shape: (batch, length, length)

# Or compute contacts together with logits, representations, etc.
outputs = model(tokens, return_contacts=True)
contacts = outputs["contacts"]
```
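The precision figure quoted for the example map is a precision-at-L style metric: among residue pairs separated by at least 24 positions, take the L highest-scoring predictions (L = sequence length) and count the fraction that are true contacts. A sketch on synthetic maps (the inputs here are hypothetical; a real evaluation would use a predicted map plus a ground-truth contact map):

```python
import numpy as np

def long_range_precision_at_l(pred, truth, min_sep=24):
    """Precision of the top-L predicted contacts with separation >= min_sep."""
    L = pred.shape[0]
    i, j = np.triu_indices(L, k=min_sep)    # long-range residue pairs
    top = np.argsort(pred[i, j])[::-1][:L]  # L highest-scoring pairs
    return float(truth[i[top], j[top]].mean())

rng = np.random.default_rng(0)
truth = (rng.random((100, 100)) < 0.05).astype(float)  # toy contact map
pred = truth + 0.1 * rng.random((100, 100))            # noisy "predictions"
precision = long_range_precision_at_l(pred, truth)
```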
### Examples

**Mutation Effect Prediction**: [notebooks/mutation_effect_prediction.ipynb](notebooks/mutation_effect_prediction.ipynb)

This notebook demonstrates how to use ESM-2 for zero-shot mutation effect prediction by scoring amino acid substitutions based on their likelihood under the model. We validate the approach using experimental fitness data from β-lactamase TEM, showing how ESM-2 captures functional constraints without requiring structural information.

**Embeddings**: [notebooks/embeddings.ipynb](notebooks/embeddings.ipynb)

This notebook explores how ESM-2 generates meaningful protein embeddings that capture evolutionary and functional relationships between proteins. We analyze six diverse human proteins to demonstrate how the learned representations cluster proteins by function and reveal biological similarities.

**Contact Prediction**: [notebooks/contact_prediction.ipynb](notebooks/contact_prediction.ipynb)

This notebook shows how to predict residue-residue contacts in protein structures using ESM-2's attention patterns. We evaluate contact prediction performance on three diverse proteins, demonstrating how the model captures both local and long-range structural relationships directly from sequence data.
### Benchmarking

Benchmark MLX performance:

```bash
python benchmarks/benchmark_mx.py
```

Benchmark PyTorch MPS performance:

```bash
python benchmarks/benchmark_pt.py
```

Expected performance on an M4 MacBook Pro (ESM-2 650M, batch_size = 5):

- MLX: 299 ms per step, 16.71 sequences/sec
- PyTorch MPS: 402 ms per step, 12.43 sequences/sec
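Throughput is derived directly from the measured step time as `batch_size * 1000 / ms_per_step`, exactly as in the benchmark scripts; for example, for the MLX row above:

```python
batch_size = 5
ms_per_step = 299  # measured MLX step time in ms (rounded)
throughput = batch_size * 1000 / ms_per_step
print(f"{throughput:.2f} sequences/sec")  # ≈ 16.7
```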
### Testing

Verify correctness against the original implementation:

```bash
python test.py
```

This tests tokenizer and model outputs (logits, hidden states, and attentions) for equivalence with the original implementation.
### Citations

```bibtex
@article{rives2019biological,
  author={Rives, Alexander and Meier, Joshua and Sercu, Tom and Goyal, Siddharth and Lin, Zeming and Liu, Jason and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
  title={Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
  year={2019},
  doi={10.1101/622803},
  url={https://www.biorxiv.org/content/10.1101/622803v4},
  journal={PNAS}
}
```

```bibtex
@article{Lin2023,
  author={Lin, Zeming and others},
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  journal={Science},
  volume={379},
  pages={1123--1130},
  year={2023},
  doi={10.1126/science.ade2574},
  url={https://doi.org/10.1126/science.ade2574}
}
```

[^1]: Refer to the [paper](https://www.science.org/doi/10.1126/science.ade2574) and [code](https://github.com/facebookresearch/esm) for more details.

esm/assets/contact_prediction.png

34.3 KB

esm/benchmarks/benchmark_mx.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import sys
import time
from pathlib import Path

import mlx.core as mx

# Add parent directory to Python path
cur_path = Path(__file__).parents[1].resolve()
sys.path.append(str(cur_path))

from esm import ESM2

# Example protein sequence (Green Fluorescent Protein)
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"

# Load pretrained ESM-2 model and its tokenizer from local checkpoint
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")

# Number of sequences to process in each forward pass
batch_size = 5

# Number of timing iterations for performance measurement
steps = 50

# Tokenize the protein sequence into integer IDs for the model;
# replicate the same sequence 'batch_size' times to create a batch
tokens = tokenizer.batch_encode([protein_sequence] * batch_size)

# Warm-up phase
for _ in range(10):
    result = model(tokens)
    mx.eval(result["logits"])  # Force computation to complete

# Measure average inference time over 'steps' iterations
tic = time.time()
for _ in range(steps):
    result = model(tokens)
    mx.eval(result["logits"])  # Synchronize and ensure computation finishes
toc = time.time()

# Compute metrics: average time per step (ms) and throughput (sequences/sec)
ms_per_step = 1000 * (toc - tic) / steps
throughput = batch_size * 1000 / ms_per_step

# Display results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")

esm/benchmarks/benchmark_pt.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import time

import torch
from transformers import AutoTokenizer, EsmForMaskedLM

# Example protein sequence (Green Fluorescent Protein)
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"

# Hugging Face model identifier for ESM-2 (33 layers, 650M params, UR50D training set)
model_name = "facebook/esm2_t33_650M_UR50D"

# Load tokenizer and model; move model to Apple Metal Performance Shaders (MPS) device
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name).to("mps")

# Number of sequences per forward pass
batch_size = 5

# Number of timing iterations
steps = 50

# Tokenize the input sequence, replicating it 'batch_size' times to create a batch
inputs = tokenizer(
    [protein_sequence] * batch_size,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=1024,
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")

# Warm-up phase
for _ in range(10):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    torch.mps.synchronize()  # Ensure all queued ops on MPS are complete before next step

# Timed inference loop
tic = time.time()
for _ in range(steps):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    torch.mps.synchronize()  # Wait for computation to finish before timing next iteration
toc = time.time()

# Compute performance metrics
ms_per_step = 1000 * (toc - tic) / steps
throughput = batch_size * 1000 / ms_per_step

# Report results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")
