Skip to content
Merged
14 changes: 5 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
name = "ESMFold"
uuid = "992ade7d-a23b-44e1-bfa9-f18d3a6e7567"
authors = ["Ben Murrell"]
version = "1.0.0-DEV"
authors = ["Ben Murrell"]

[deps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Einops = "e3ce28c8-8bfb-4704-8add-e3e7f14b55c9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Expand All @@ -17,22 +16,19 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
BFloat16s = "0.5, 0.6"
CUDA = "5.9.6"
ChainRulesCore = "1.26.0"
Einops = "0.1.14"
Flux = "0.16.9"
HuggingFaceApi = "0.1"
JSON = "1.4.0"
NNlib = "0.9.33"
NPZ = "0.4.3"
Onion = "0.2.18"
SpecialFunctions = "2"
Zygote = "0.7.10"
cuDNN = "1.4.6"
julia = "1.10"

[extras]
Expand Down
126 changes: 125 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@ A Julia port of the **full ESMFold model**: ESM2 embeddings + folding trunk + st
This repo runs end‑to‑end folding on CPU, and will run on GPU when you move the model/tensors
to the GPU.

## Installation

Some dependencies (Onion, Einops, BatchedTransformations, etc.) live in the
MurrellGroup registry. Add it **once** alongside the default General registry:

```julia
using Pkg
Pkg.Registry.add(Pkg.RegistrySpec(url="https://github.com/MurrellGroup/MurrellGroupRegistry"))
```

Then install ESMFold.jl (from a local clone):

```julia
Pkg.develop(path="path/to/ESMFold.jl")
```

Or, if the package is registered in the MurrellGroup registry:

```julia
Pkg.add("ESMFold")
```

## Quickstart (single sequence)

```julia
Expand Down Expand Up @@ -62,6 +84,87 @@ pae = metrics.predicted_aligned_error
max_pae = metrics.max_predicted_aligned_error
```

## Pipeline API

In addition to the monolithic `infer()`, ESMFold.jl exports composable pipeline stages
that give you access to intermediate representations. All functions work on both CPU and
GPU — tensors follow the model device automatically.

### Pipeline overview

```
prepare_inputs → run_embedding → run_trunk → run_heads → (post‑processing)
╰─ run_esm2 ╰─ run_trunk_single_pass
╰─ run_structure_module
```

`run_pipeline(model, sequences)` chains all stages and produces output identical to
`infer()`. The individual stages can be called separately for research workflows.

### Stage reference

| Function | Input | Output | Description |
|----------|-------|--------|-------------|
| `prepare_inputs(model, seqs)` | sequences | NamedTuple | Encode + device transfer |
| `run_esm2(model, inputs)` | prepared inputs | `ESM2Output` | Raw ESM2 with BOS/EOS wrapping |
| `run_embedding(model, inputs)` | prepared inputs | `(s_s_0, s_z_0)` | ESM2 + projection to trunk dims |
| `run_trunk(model, s_s_0, s_z_0, inputs)` | embeddings | Dict | Full trunk: recycling + structure module |
| `run_trunk_single_pass(model, s_s, s_z, inputs)` | states | `(s_s, s_z)` | One pass through 48 blocks (no recycling) |
| `run_structure_module(model, s_s, s_z, inputs)` | trunk states | Dict | Structure module on custom states |
| `run_heads(model, structure, inputs)` | structure Dict | Dict | Distogram, PTM, lDDT, LM heads |
| `run_pipeline(model, seqs)` | sequences | Dict | Full pipeline (identical to `infer`) |

### Examples

**Get ESM2 embeddings:**

```julia
inputs = prepare_inputs(model, "MKQLLED...")
esm_out = run_esm2(model, inputs; repr_layers=collect(0:33))
esm_out.representations[33] # (B, T, C) last-layer hidden states
```

**Get trunk output without the structure module:**

```julia
inputs = prepare_inputs(model, "MKQLLED...")
emb = run_embedding(model, inputs)
result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
result.s_s # (1024, L, B) sequence state
result.s_z # (128, L, L, B) pairwise state
```

**Run structure module on custom features:**

```julia
structure = run_structure_module(model, custom_s_s, custom_s_z, inputs)
```

**Get distograms from one pass:**

```julia
emb = run_embedding(model, inputs)
result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
structure = run_structure_module(model, result.s_s, result.s_z, inputs)
output = run_heads(model, structure, inputs)
output[:distogram_logits] # (64, L, L, B)
```

### AD‑compatible ESM2 forward

The standard ESM2 forward uses in‑place GPU ops that Zygote cannot differentiate.
`esm2_forward_ad` provides an allocating replacement:

```julia
using Zygote

# tokens_bt: (B, T) 0-indexed token array (from ESM2's Alphabet conventions)
grads = Zygote.gradient(model.embed.esm) do esm
x = esm2_forward_ad(esm, tokens_bt)
sum(x)
end
```

## Weights and Caching

`load_ESMFold()` downloads the safetensors checkpoint from Hugging Face using
Expand Down Expand Up @@ -97,9 +200,30 @@ resulting PDB against `scripts/output_ELLKKLLEELKG.pdb`:
julia --project=. scripts/test.jl
```

## GPU Inference

ESMFold.jl ships with CUDA support. To run on GPU, make sure `CUDA.jl` and
`cuDNN.jl` are available in your environment, move the model with `Flux.gpu`,
and call `infer` as usual:

```julia
using CUDA, cuDNN
using Flux
using ESMFold

model = load_ESMFold()
gpu_model = Flux.gpu(model)

output = infer(gpu_model, "ELLKKLLEELKG")
pdb = output_to_pdb(output)[1]
```

All intermediate tensors automatically follow the model to the GPU.
`output_to_pdb` handles moving results back to CPU.

## Notes

- CPU‑only execution is supported.
- Both CPU and GPU execution are supported.
- The implementation follows the ESMFold Python model closely and is parity‑checked
against the official model within floating‑point tolerances.

Expand Down
38 changes: 36 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,44 @@ Use `output_to_pdb` to export PDBs.
## Input Modes

- `AbstractMatrix{Int}` shaped `(B, L)`
- `Vector{Vector{Int}}` (autopadded)
- `Vector{Vector{Int}}` (auto-padded)
- `Vector{String}` or a single `String`

See the README for more usage examples and batch folding.
## Pipeline API

The inference pipeline is decomposed into composable stages. Each stage can be called
independently for research workflows (extracting embeddings, running partial inference,
feeding custom features, etc.).

```
prepare_inputs → run_embedding → run_trunk → run_heads → (post-processing)
╰─ run_esm2 ╰─ run_trunk_single_pass
╰─ run_structure_module
```

### Stages

- **`prepare_inputs(model, sequences)`** — encode sequences and transfer to model device
- **`run_esm2(model, inputs)`** — raw ESM2 forward with BOS/EOS wrapping
- **`run_embedding(model, inputs)`** — ESM2 + projection to trunk dimensions → `(s_s_0, s_z_0)`
- **`run_trunk(model, s_s_0, s_z_0, inputs)`** — full trunk with recycling + structure module
- **`run_trunk_single_pass(model, s_s, s_z, inputs)`** — one pass through 48 blocks (no recycling, no structure module)
- **`run_structure_module(model, s_s, s_z, inputs)`** — structure module on arbitrary trunk outputs
- **`run_heads(model, structure, inputs)`** — all output heads (distogram, PTM, lDDT, LM)
- **`run_pipeline(model, sequences)`** — full pipeline, identical output to `infer()`

### AD-compatible ESM2

`esm2_forward_ad(esm, tokens_bt)` is a Zygote-compatible ESM2 forward that replaces
in-place ops with allocating equivalents. Use it when you need gradients through the
language model.

### Constants

`DISTOGRAM_BINS`, `LDDT_BINS`, `NUM_ATOM_TYPES`, `RECYCLE_DISTANCE_BINS` — named
constants for model dimensions, replacing magic numbers.

See the README for detailed examples.

```@index
```
Expand Down
74 changes: 60 additions & 14 deletions src/ESMFold.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,69 @@
module ESMFold

using LinearAlgebra
using Statistics

using Einops
using Flux
using NNlib
using Onion
using NPZ
using JSON
using SpecialFunctions
using HuggingFaceApi

include("device_utils.jl")
include("safetensors.jl")
include("constants.jl")

# GPU support
using CUDA

# Import set_chunk_size! for extension on ESMFoldModel
import Onion: set_chunk_size!

# Import all protein layer types and utilities from Onion
using Onion: LayerNormFirst, LinearFirst, layernorm_inplace!,
AbstractRotation, RotMatRotation, QuatRotation, Rigid,
rot_from_mat, rot_from_quat, rot_matmul_first, rot_vec_mul_first,
quat_multiply_first, quat_multiply_by_vec_first, quat_to_rot_first,
rotation_identity, get_rot_mats, get_quats,
compose_q_update_vec, rigid_identity, apply_rotation, apply_rigid,
invert_apply_rigid, compose, scale_translation, to_tensor_7, to_tensor_4x4,
rigid_index,
permute_final_dims, flatten_final_dims, dict_multimap, stack_dicts,
one_hot_last, collate_dense_tensors,
rigid_from_tensor_4x4, torsion_angles_to_frames,
frames_and_literature_positions_to_atom14_pos, atom14_to_atom37,
RotaryEmbedding, rotate_half, apply_rotary_pos_emb,
_update_cos_sin!,
ESMFoldAttention, SharedDropout, SequenceToPair, PairToSequence, ResidueMLP,
set_training!, is_training,
encode_sequence, batch_encode_sequences,
CategoricalMixture, categorical_lddt,
ESMMultiheadAttention, _reshape_heads, _unshape_heads, _apply_key_padding_mask!,
OFMultiheadAttention, TriangleAttention, TriangleMultiplicativeUpdate,
TriangleMultiplicationOutgoing, TriangleMultiplicationIncoming,
TriangularSelfAttentionBlock,
StructureModuleConfig, PointProjection, ESMFoldIPA,
BackboneUpdate, StructureModuleTransitionLayer, StructureModuleTransition,
AngleResnetBlock, AngleResnet, StructureModule,
_init_residue_constants!,
ESMFoldEmbedConfig, LayerNormMLP,
FoldingTrunkConfig, RelativePosition, FoldingTrunk,
cross_first, distogram, set_chunk_size!,
layernorm_first_forward, flash_attention_forward, flash_attention_bias_forward,
rotary_pos_emb_forward, combine_projections_forward,
_view_first1, _view_first2, _view_last1, _view_last2,
_reshape_first_corder

# Import residue constants from Onion
using Onion: restypes, restypes_with_x, restype_order, restype_order_with_x, restype_num,
atom_types, atom_order, atom_type_num,
restype_1to3, restype_3to1,
residue_atoms, restype_name_to_atom14_names,
restype_atom14_to_rigid_group, restype_atom14_mask,
restype_atom14_rigid_group_positions, restype_rigid_group_default_frame

# Alias: InvariantPointAttention → ESMFoldIPA for backward compatibility
const InvariantPointAttention = ESMFoldIPA

export Alphabet, RestypeTable, Alphabet_from_architecture
export ESM2Config, ESM2
Expand All @@ -29,24 +79,20 @@ export confidence_metrics
export set_training!, is_training
export make_atom14_masks!
export compute_tm, compute_predicted_aligned_error, categorical_lddt
export DISTOGRAM_BINS, LDDT_BINS, NUM_ATOM_TYPES, RECYCLE_DISTANCE_BINS
export esm2_forward_ad
export prepare_inputs, run_esm2, run_embedding
export run_trunk, run_trunk_single_pass
export run_structure_module, run_heads, run_pipeline

include("alphabet.jl")
include("residue_constants.jl")
include("openfold_utils.jl")
include("layernorm.jl")
include("rigid.jl")
include("openfold_feats.jl")
include("esmfold_misc.jl")
include("rotary.jl")
include("attention.jl")
include("esm2.jl")
include("esm2_ad.jl")
include("esmfold_embed.jl")
include("triangular.jl")
include("structure_module.jl")
include("folding_trunk.jl")
include("openfold_infer_utils.jl")
include("protein.jl")
include("esmfold_full.jl")
include("pipeline.jl")
include("weights.jl")

end
Loading
Loading