@@ -84,6 +84,87 @@ pae = metrics.predicted_aligned_error
 max_pae = metrics.max_predicted_aligned_error
 ```

+## Pipeline API
+
+In addition to the monolithic `infer()`, ESMFold.jl exports composable pipeline stages
+that give you access to intermediate representations. All functions work on both CPU and
+GPU; tensors follow the model device automatically.
+
+### Pipeline overview
+
+```
+prepare_inputs → run_embedding → run_trunk → run_heads → (post-processing)
+                 ╰─ run_esm2     ╰─ run_trunk_single_pass
+                                 ╰─ run_structure_module
+```
+
+`run_pipeline(model, sequences)` chains all stages and produces output identical to
+`infer()`. The individual stages can be called separately for research workflows.
+
+### Stage reference
+
+| Function | Input | Output | Description |
+|----------|-------|--------|-------------|
+| `prepare_inputs(model, seqs)` | sequences | NamedTuple | Encode + device transfer |
+| `run_esm2(model, inputs)` | prepared inputs | `ESM2Output` | Raw ESM2 with BOS/EOS wrapping |
+| `run_embedding(model, inputs)` | prepared inputs | `(s_s_0, s_z_0)` | ESM2 + projection to trunk dims |
+| `run_trunk(model, s_s_0, s_z_0, inputs)` | embeddings | Dict | Full trunk: recycling + structure module |
+| `run_trunk_single_pass(model, s_s, s_z, inputs)` | states | `(s_s, s_z)` | One pass through 48 blocks (no recycling) |
+| `run_structure_module(model, s_s, s_z, inputs)` | trunk states | Dict | Structure module on custom states |
+| `run_heads(model, structure, inputs)` | structure Dict | Dict | Distogram, PTM, lDDT, LM heads |
+| `run_pipeline(model, seqs)` | sequences | Dict | Full pipeline (identical to `infer`) |
+
+### Examples
+
+**Get ESM2 embeddings:**
+
+```julia
+inputs = prepare_inputs(model, "MKQLLED...")
+esm_out = run_esm2(model, inputs; repr_layers=collect(0:33))
+esm_out.representations[33]  # (B, T, C) last-layer hidden states
+```
+
+**Get trunk output without the structure module:**
+
+```julia
+inputs = prepare_inputs(model, "MKQLLED...")
+emb = run_embedding(model, inputs)
+result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
+result.s_s  # (1024, L, B) sequence state
+result.s_z  # (128, L, L, B) pairwise state
+```
+
+**Run structure module on custom features:**
+
+```julia
+structure = run_structure_module(model, custom_s_s, custom_s_z, inputs)
+```
+
+**Get distograms from one pass:**
+
+```julia
+emb = run_embedding(model, inputs)
+result = run_trunk_single_pass(model, emb.s_s_0, emb.s_z_0, inputs)
+structure = run_structure_module(model, result.s_s, result.s_z, inputs)
+output = run_heads(model, structure, inputs)
+output[:distogram_logits]  # (64, L, L, B)
+```
+
+### AD-compatible ESM2 forward
+
+The standard ESM2 forward uses in-place GPU ops that Zygote cannot differentiate.
+`esm2_forward_ad` provides an allocating replacement:
+
+```julia
+using Zygote
+
+# tokens_bt: (B, T) 0-indexed token array (from ESM2's Alphabet conventions)
+grads = Zygote.gradient(model.embed.esm) do esm
+    x = esm2_forward_ad(esm, tokens_bt)
+    sum(x)
+end
+```
+
 ## Weights And Caching

 `load_ESMFold()` downloads the safetensors checkpoint from Hugging Face using