Skip to content

Commit befcd8b

Browse files
claudeyclaude
andcommitted
Update docs with pipeline API, AD-compatible ESM2, and constants
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a183e6f commit befcd8b

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed

README.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,87 @@ pae = metrics.predicted_aligned_error
8484
max_pae = metrics.max_predicted_aligned_error
8585
```
8686

87+
## Pipeline API
88+
89+
In addition to the monolithic `infer()`, ESMFold.jl exports composable pipeline stages
90+
that give you access to intermediate representations. All functions work on both CPU and
91+
GPU — tensors follow the model device automatically.
92+
93+
### Pipeline overview
94+
95+
```
96+
prepare_inputs → run_embedding → run_trunk → run_heads → (post‑processing)
97+
╰─ run_esm2 ╰─ run_trunk_single_pass
98+
╰─ run_structure_module
99+
```
100+
101+
`run_pipeline(model, sequences)` chains all stages and produces output identical to
102+
`infer()`. The individual stages can be called separately for research workflows.
103+
104+
### Stage reference
105+
106+
| Function | Input | Output | Description |
107+
|----------|-------|--------|-------------|
108+
| `prepare_inputs(model, seqs)` | sequences | NamedTuple | Encode + device transfer |
109+
| `run_esm2(model, inputs)` | prepared inputs | `ESM2Output` | Raw ESM2 with BOS/EOS wrapping |
110+
| `run_embedding(model, inputs)` | prepared inputs | `(s_s_0, s_z_0)` | ESM2 + projection to trunk dims |
111+
| `run_trunk(model, s_s_0, s_z_0, inputs)` | embeddings | Dict | Full trunk: recycling + structure module |
112+
| `run_trunk_single_pass(model, s_s, s_z, inputs)` | states | `(s_s, s_z)` | One pass through 48 blocks (no recycling) |
113+
| `run_structure_module(model, s_s, s_z, inputs)` | trunk states | Dict | Structure module on custom states |
114+
| `run_heads(model, structure, inputs)` | structure Dict | Dict | Distogram, PTM, lDDT, LM heads |
115+
| `run_pipeline(model, seqs)` | sequences | Dict | Full pipeline (identical to `infer`) |
116+
117+
### Examples
118+
119+
**Get ESM2 embeddings:**
120+
121+
```julia
122+
inputs = prepare_inputs(model, "MKQLLED...")
123+
esm_out = run_esm2(model, inputs; repr_layers=collect(0:33))
124+
esm_out.representations[33] # (B, T, C) last-layer hidden states
125+
```
126+
127+
**Get trunk output without the structure module:**
128+
129+
```julia
130+
inputs = prepare_inputs(model, "MKQLLED...")
131+
emb = run_embedding(model, inputs)
132+
result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
133+
result.s_s # (1024, L, B) sequence state
134+
result.s_z # (128, L, L, B) pairwise state
135+
```
136+
137+
**Run structure module on custom features:**
138+
139+
```julia
140+
structure = run_structure_module(model, custom_s_s, custom_s_z, inputs)
141+
```
142+
143+
**Get distograms from one pass:**
144+
145+
```julia
146+
emb = run_embedding(model, inputs)
147+
result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
148+
structure = run_structure_module(model, result.s_s, result.s_z, inputs)
149+
output = run_heads(model, structure, inputs)
150+
output[:distogram_logits] # (64, L, L, B)
151+
```
152+
153+
### AD‑compatible ESM2 forward
154+
155+
The standard ESM2 forward uses in‑place GPU ops that Zygote cannot differentiate.
156+
`esm2_forward_ad` provides an allocating replacement:
157+
158+
```julia
159+
using Zygote
160+
161+
# tokens_bt: (B, T) 0-indexed token array (from ESM2's Alphabet conventions)
162+
grads = Zygote.gradient(model.embed.esm) do esm
163+
x = esm2_forward_ad(esm, tokens_bt)
164+
sum(x)
165+
end
166+
```
167+
87168
## Weights And Caching
88169

89170
`load_ESMFold()` downloads the safetensors checkpoint from Hugging Face using

docs/src/index.md

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,44 @@ Use `output_to_pdb` to export PDBs.
2828
## Input Modes
2929

3030
- `AbstractMatrix{Int}` shaped `(B, L)`
31-
- `Vector{Vector{Int}}` (autopadded)
31+
- `Vector{Vector{Int}}` (auto-padded)
3232
- `Vector{String}` or a single `String`
3333

34-
See the README for more usage examples and batch folding.
34+
## Pipeline API
35+
36+
The inference pipeline is decomposed into composable stages. Each stage can be called
37+
independently for research workflows (extracting embeddings, running partial inference,
38+
feeding custom features, etc.).
39+
40+
```
41+
prepare_inputs → run_embedding → run_trunk → run_heads → (post-processing)
42+
╰─ run_esm2 ╰─ run_trunk_single_pass
43+
╰─ run_structure_module
44+
```
45+
46+
### Stages
47+
48+
- **`prepare_inputs(model, sequences)`** — encode sequences and transfer to model device
49+
- **`run_esm2(model, inputs)`** — raw ESM2 forward with BOS/EOS wrapping
50+
- **`run_embedding(model, inputs)`** — ESM2 + projection to trunk dimensions → `(s_s_0, s_z_0)`
51+
- **`run_trunk(model, s_s_0, s_z_0, inputs)`** — full trunk with recycling + structure module
52+
- **`run_trunk_single_pass(model, s_s, s_z, inputs)`** — one pass through 48 blocks (no recycling, no structure module)
53+
- **`run_structure_module(model, s_s, s_z, inputs)`** — structure module on arbitrary trunk outputs
54+
- **`run_heads(model, structure, inputs)`** — all output heads (distogram, PTM, lDDT, LM)
55+
- **`run_pipeline(model, sequences)`** — full pipeline, identical output to `infer()`
56+
57+
### AD-compatible ESM2
58+
59+
`esm2_forward_ad(esm, tokens_bt)` is a Zygote-compatible ESM2 forward that replaces
60+
in-place ops with allocating equivalents. Use it when you need gradients through the
61+
language model.
62+
63+
### Constants
64+
65+
`DISTOGRAM_BINS`, `LDDT_BINS`, `NUM_ATOM_TYPES`, `RECYCLE_DISTANCE_BINS` — named
66+
constants for model dimensions, replacing magic numbers.
67+
68+
See the README for detailed examples.
3569

3670
```@index
3771
```

0 commit comments

Comments
 (0)