
Commit 52fed45

Merge pull request #5 from MurrellGroup/optimized: Optimized

2 parents 678afa6 + befcd8b

24 files changed: +778 −2250 lines

Project.toml

Lines changed: 5 additions & 9 deletions

```diff
@@ -1,13 +1,12 @@
 name = "ESMFold"
 uuid = "992ade7d-a23b-44e1-bfa9-f18d3a6e7567"
-authors = ["Ben Murrell"]
 version = "1.0.0-DEV"
+authors = ["Ben Murrell"]
 
 [deps]
 BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
-BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-Einops = "e3ce28c8-8bfb-4704-8add-e3e7f14b55c9"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
@@ -17,22 +16,19 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
 Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 BFloat16s = "0.5, 0.6"
+CUDA = "5.9.6"
 ChainRulesCore = "1.26.0"
-Einops = "0.1.14"
 Flux = "0.16.9"
 HuggingFaceApi = "0.1"
 JSON = "1.4.0"
 NNlib = "0.9.33"
 NPZ = "0.4.3"
 Onion = "0.2.18"
-SpecialFunctions = "2"
-Zygote = "0.7.10"
+cuDNN = "1.4.6"
 julia = "1.10"
 
 [extras]
```

README.md

Lines changed: 125 additions & 1 deletion

````diff
@@ -4,6 +4,28 @@ A Julia port of the **full ESMFold model**: ESM2 embeddings + folding trunk + st
 This repo runs end‑to‑end folding on CPU, and will run on GPU when you move the model/tensors
 to the GPU.
 
+## Installation
+
+Some dependencies (Onion, Einops, BatchedTransformations, etc.) live in the
+MurrellGroup registry. Add it **once** alongside the default General registry:
+
+```julia
+using Pkg
+Pkg.Registry.add("https://github.com/MurrellGroup/MurrellGroupRegistry")
+```
+
+Then install ESMFold.jl (from a local clone):
+
+```julia
+Pkg.develop(path="path/to/ESMFold.jl")
+```
+
+Or, if the package is registered in the MurrellGroup registry:
+
+```julia
+Pkg.add("ESMFold")
+```
+
 ## Quickstart (single sequence)
 
 ```julia
````

````diff
@@ -62,6 +84,87 @@ pae = metrics.predicted_aligned_error
 max_pae = metrics.max_predicted_aligned_error
 ```
 
+## Pipeline API
+
+In addition to the monolithic `infer()`, ESMFold.jl exports composable pipeline stages
+that give you access to intermediate representations. All functions work on both CPU and
+GPU — tensors follow the model device automatically.
+
+### Pipeline overview
+
+```
+prepare_inputs → run_embedding → run_trunk → run_heads → (post‑processing)
+      ╰─ run_esm2      ╰─ run_trunk_single_pass
+                       ╰─ run_structure_module
+```
+
+`run_pipeline(model, sequences)` chains all stages and produces output identical to
+`infer()`. The individual stages can be called separately for research workflows.
+
+### Stage reference
+
+| Function | Input | Output | Description |
+|----------|-------|--------|-------------|
+| `prepare_inputs(model, seqs)` | sequences | NamedTuple | Encode + device transfer |
+| `run_esm2(model, inputs)` | prepared inputs | `ESM2Output` | Raw ESM2 with BOS/EOS wrapping |
+| `run_embedding(model, inputs)` | prepared inputs | `(s_s_0, s_z_0)` | ESM2 + projection to trunk dims |
+| `run_trunk(model, s_s_0, s_z_0, inputs)` | embeddings | Dict | Full trunk: recycling + structure module |
+| `run_trunk_single_pass(model, s_s, s_z, inputs)` | states | `(s_s, s_z)` | One pass through 48 blocks (no recycling) |
+| `run_structure_module(model, s_s, s_z, inputs)` | trunk states | Dict | Structure module on custom states |
+| `run_heads(model, structure, inputs)` | structure Dict | Dict | Distogram, PTM, lDDT, LM heads |
+| `run_pipeline(model, seqs)` | sequences | Dict | Full pipeline (identical to `infer`) |
+
+### Examples
+
+**Get ESM2 embeddings:**
+
+```julia
+inputs = prepare_inputs(model, "MKQLLED...")
+esm_out = run_esm2(model, inputs; repr_layers=collect(0:33))
+esm_out.representations[33]  # (B, T, C) last-layer hidden states
+```
+
+**Get trunk output without the structure module:**
+
+```julia
+inputs = prepare_inputs(model, "MKQLLED...")
+emb = run_embedding(model, inputs)
+result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
+result.s_s  # (1024, L, B) sequence state
+result.s_z  # (128, L, L, B) pairwise state
+```
+
+**Run structure module on custom features:**
+
+```julia
+structure = run_structure_module(model, custom_s_s, custom_s_z, inputs)
+```
+
+**Get distograms from one pass:**
+
+```julia
+emb = run_embedding(model, inputs)
+result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
+structure = run_structure_module(model, result.s_s, result.s_z, inputs)
+output = run_heads(model, structure, inputs)
+output[:distogram_logits]  # (64, L, L, B)
+```
+
+### AD‑compatible ESM2 forward
+
+The standard ESM2 forward uses in‑place GPU ops that Zygote cannot differentiate.
+`esm2_forward_ad` provides an allocating replacement:
+
+```julia
+using Zygote
+
+# tokens_bt: (B, T) 0-indexed token array (from ESM2's Alphabet conventions)
+grads = Zygote.gradient(model.embed.esm) do esm
+    x = esm2_forward_ad(esm, tokens_bt)
+    sum(x)
+end
+```
+
 ## Weights And Caching
 
 `load_ESMFold()` downloads the safetensors checkpoint from Hugging Face using
````

````diff
@@ -97,9 +200,30 @@ resulting PDB against `scripts/output_ELLKKLLEELKG.pdb`:
 julia --project=. scripts/test.jl
 ```
 
+## GPU Inference
+
+ESMFold.jl has no direct CUDA dependency. To run on GPU, add `CUDA.jl` and
+`cuDNN.jl` to your own project environment, move the model with `Flux.gpu`,
+and call `infer` as usual:
+
+```julia
+using CUDA, cuDNN
+using Flux
+using ESMFold
+
+model = load_ESMFold()
+gpu_model = Flux.gpu(model)
+
+output = infer(gpu_model, "ELLKKLLEELKG")
+pdb = output_to_pdb(output)[1]
+```
+
+All intermediate tensors automatically follow the model to the GPU.
+`output_to_pdb` handles moving results back to CPU.
+
 ## Notes
 
-- CPU‑only execution is supported.
+- Both CPU and GPU execution are supported.
 - The implementation follows the ESMFold Python model closely and is parity‑checked
   against the official model within floating‑point tolerances.
 
````
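The `run_heads` distogram example in this diff yields raw `(64, L, L, B)` logits. As a dependency-free sketch of one way to post-process them, the snippet below reduces per-pair bin logits to an expected-distance map. The bin midpoints are illustrative assumptions, not values taken from ESMFold's configuration.

```julia
# Hypothetical post-processing sketch: reduce (bins, L, L) distogram logits
# to an expected pairwise distance map. The midpoints below are assumed for
# illustration; they are NOT ESMFold's actual distance-bin edges.
softmax_bins(x) = (e = exp.(x .- maximum(x; dims=1)); e ./ sum(e; dims=1))

nbins = 64                                    # matches the (64, L, L, B) logit shape
midpoints = collect(range(2.3125f0, 21.6875f0; length=nbins))  # assumed Å midpoints

logits = randn(Float32, nbins, 5, 5)          # stand-in for one batch element of the logits
probs = softmax_bins(logits)                  # per-pair categorical over distance bins
expected = dropdims(sum(midpoints .* probs; dims=1); dims=1)   # (L, L) expected distances
```

Because each bin distribution sums to one, `expected` is a convex combination of the midpoints and stays within their range.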

docs/src/index.md

Lines changed: 36 additions & 2 deletions

````diff
@@ -28,10 +28,44 @@ Use `output_to_pdb` to export PDBs.
 ## Input Modes
 
 - `AbstractMatrix{Int}` shaped `(B, L)`
-- `Vector{Vector{Int}}` (autopadded)
+- `Vector{Vector{Int}}` (auto-padded)
 - `Vector{String}` or a single `String`
 
-See the README for more usage examples and batch folding.
+## Pipeline API
+
+The inference pipeline is decomposed into composable stages. Each stage can be called
+independently for research workflows (extracting embeddings, running partial inference,
+feeding custom features, etc.).
+
+```
+prepare_inputs → run_embedding → run_trunk → run_heads → (post-processing)
+      ╰─ run_esm2      ╰─ run_trunk_single_pass
+                       ╰─ run_structure_module
+```
+
+### Stages
+
+- **`prepare_inputs(model, sequences)`** — encode sequences and transfer to model device
+- **`run_esm2(model, inputs)`** — raw ESM2 forward with BOS/EOS wrapping
+- **`run_embedding(model, inputs)`** — ESM2 + projection to trunk dimensions → `(s_s_0, s_z_0)`
+- **`run_trunk(model, s_s_0, s_z_0, inputs)`** — full trunk with recycling + structure module
+- **`run_trunk_single_pass(model, s_s, s_z, inputs)`** — one pass through 48 blocks (no recycling, no structure module)
+- **`run_structure_module(model, s_s, s_z, inputs)`** — structure module on arbitrary trunk outputs
+- **`run_heads(model, structure, inputs)`** — all output heads (distogram, PTM, lDDT, LM)
+- **`run_pipeline(model, sequences)`** — full pipeline, identical output to `infer()`
+
+### AD-compatible ESM2
+
+`esm2_forward_ad(esm, tokens_bt)` is a Zygote-compatible ESM2 forward that replaces
+in-place ops with allocating equivalents. Use it when you need gradients through the
+language model.
+
+### Constants
+
+`DISTOGRAM_BINS`, `LDDT_BINS`, `NUM_ATOM_TYPES`, `RECYCLE_DISTANCE_BINS` — named
+constants for model dimensions, replacing magic numbers.
+
+See the README for detailed examples.
 
 ```@index
 ```
````
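The docs above list `Vector{Vector{Int}}` as an auto-padded input mode. A minimal sketch of the idea, using a hypothetical `pad_batch` helper (not the package's internal implementation; the pad value `0` is an assumption for illustration):

```julia
# Hypothetical sketch of auto-padding ragged integer sequences into a (B, L)
# matrix, as in the `Vector{Vector{Int}}` input mode. The pad token `0` is an
# assumed placeholder; the real encoder may use a different padding value.
function pad_batch(seqs::Vector{Vector{Int}}; pad::Int = 0)
    B = length(seqs)
    L = maximum(length, seqs)          # pad every sequence to the longest
    out = fill(pad, B, L)
    for (i, s) in enumerate(seqs)
        out[i, 1:length(s)] .= s       # copy tokens; trailing slots stay padded
    end
    return out
end

pad_batch([[1, 2, 3], [4, 5]])  # 2×3 matrix; second row is [4, 5, 0]
```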

src/ESMFold.jl

Lines changed: 60 additions & 14 deletions

```diff
@@ -1,19 +1,69 @@
 module ESMFold
 
 using LinearAlgebra
-using Statistics
 
-using Einops
 using Flux
 using NNlib
 using Onion
 using NPZ
 using JSON
-using SpecialFunctions
 using HuggingFaceApi
 
 include("device_utils.jl")
 include("safetensors.jl")
+include("constants.jl")
+
+# GPU support
+using CUDA
+
+# Import set_chunk_size! for extension on ESMFoldModel
+import Onion: set_chunk_size!
+
+# Import all protein layer types and utilities from Onion
+using Onion: LayerNormFirst, LinearFirst, layernorm_inplace!,
+    AbstractRotation, RotMatRotation, QuatRotation, Rigid,
+    rot_from_mat, rot_from_quat, rot_matmul_first, rot_vec_mul_first,
+    quat_multiply_first, quat_multiply_by_vec_first, quat_to_rot_first,
+    rotation_identity, get_rot_mats, get_quats,
+    compose_q_update_vec, rigid_identity, apply_rotation, apply_rigid,
+    invert_apply_rigid, compose, scale_translation, to_tensor_7, to_tensor_4x4,
+    rigid_index,
+    permute_final_dims, flatten_final_dims, dict_multimap, stack_dicts,
+    one_hot_last, collate_dense_tensors,
+    rigid_from_tensor_4x4, torsion_angles_to_frames,
+    frames_and_literature_positions_to_atom14_pos, atom14_to_atom37,
+    RotaryEmbedding, rotate_half, apply_rotary_pos_emb,
+    _update_cos_sin!,
+    ESMFoldAttention, SharedDropout, SequenceToPair, PairToSequence, ResidueMLP,
+    set_training!, is_training,
+    encode_sequence, batch_encode_sequences,
+    CategoricalMixture, categorical_lddt,
+    ESMMultiheadAttention, _reshape_heads, _unshape_heads, _apply_key_padding_mask!,
+    OFMultiheadAttention, TriangleAttention, TriangleMultiplicativeUpdate,
+    TriangleMultiplicationOutgoing, TriangleMultiplicationIncoming,
+    TriangularSelfAttentionBlock,
+    StructureModuleConfig, PointProjection, ESMFoldIPA,
+    BackboneUpdate, StructureModuleTransitionLayer, StructureModuleTransition,
+    AngleResnetBlock, AngleResnet, StructureModule,
+    _init_residue_constants!,
+    ESMFoldEmbedConfig, LayerNormMLP,
+    FoldingTrunkConfig, RelativePosition, FoldingTrunk,
+    cross_first, distogram, set_chunk_size!,
+    layernorm_first_forward, flash_attention_forward, flash_attention_bias_forward,
+    rotary_pos_emb_forward, combine_projections_forward,
+    _view_first1, _view_first2, _view_last1, _view_last2,
+    _reshape_first_corder
+
+# Import residue constants from Onion
+using Onion: restypes, restypes_with_x, restype_order, restype_order_with_x, restype_num,
+    atom_types, atom_order, atom_type_num,
+    restype_1to3, restype_3to1,
+    residue_atoms, restype_name_to_atom14_names,
+    restype_atom14_to_rigid_group, restype_atom14_mask,
+    restype_atom14_rigid_group_positions, restype_rigid_group_default_frame
+
+# Alias: InvariantPointAttention → ESMFoldIPA for backward compatibility
+const InvariantPointAttention = ESMFoldIPA
 
 export Alphabet, RestypeTable, Alphabet_from_architecture
 export ESM2Config, ESM2
@@ -29,24 +79,20 @@ export confidence_metrics
 export set_training!, is_training
 export make_atom14_masks!
 export compute_tm, compute_predicted_aligned_error, categorical_lddt
+export DISTOGRAM_BINS, LDDT_BINS, NUM_ATOM_TYPES, RECYCLE_DISTANCE_BINS
+export esm2_forward_ad
+export prepare_inputs, run_esm2, run_embedding
+export run_trunk, run_trunk_single_pass
+export run_structure_module, run_heads, run_pipeline
 
 include("alphabet.jl")
-include("residue_constants.jl")
-include("openfold_utils.jl")
-include("layernorm.jl")
-include("rigid.jl")
-include("openfold_feats.jl")
-include("esmfold_misc.jl")
-include("rotary.jl")
-include("attention.jl")
 include("esm2.jl")
+include("esm2_ad.jl")
 include("esmfold_embed.jl")
-include("triangular.jl")
-include("structure_module.jl")
-include("folding_trunk.jl")
 include("openfold_infer_utils.jl")
 include("protein.jl")
 include("esmfold_full.jl")
+include("pipeline.jl")
 include("weights.jl")
 
 end
```
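The diff keeps old call sites working via `const InvariantPointAttention = ESMFoldIPA`. A self-contained sketch of this aliasing pattern, using a stand-in struct rather than the real layer from Onion:

```julia
# Stand-in struct for illustration; the real ESMFoldIPA lives in Onion.
struct ESMFoldIPA
    dim::Int
end

# A const binding makes the old name a true alias of the new type:
# zero overhead, and dispatch/isa checks see one and the same type.
const InvariantPointAttention = ESMFoldIPA

layer = InvariantPointAttention(384)   # constructs an ESMFoldIPA
```

Because the alias is the type itself rather than a wrapper, methods defined on `ESMFoldIPA` apply unchanged to code written against the old name.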
