Skip to content

Commit a183e6f

Browse files
claude authored and committed
Add modular pipeline API: composable inference stages for research use
New exports: prepare_inputs, run_esm2, run_embedding, run_trunk, run_trunk_single_pass, run_structure_module, run_heads, run_pipeline. Also adds esm2_forward_ad (AD-compatible ESM2 forward) and named constants (DISTOGRAM_BINS, LDDT_BINS, NUM_ATOM_TYPES, RECYCLE_DISTANCE_BINS). run_pipeline() produces bit-identical output to infer() — verified on GPU with 0.0 max diff across all 25 output keys. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7dfdac4 commit a183e6f

File tree

5 files changed

+507
-3
lines changed

5 files changed

+507
-3
lines changed

src/ESMFold.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using HuggingFaceApi
1111

1212
include("device_utils.jl")
1313
include("safetensors.jl")
14+
include("constants.jl")
1415

1516
# GPU support
1617
using CUDA
@@ -78,13 +79,20 @@ export confidence_metrics
7879
export set_training!, is_training
7980
export make_atom14_masks!
8081
export compute_tm, compute_predicted_aligned_error, categorical_lddt
82+
export DISTOGRAM_BINS, LDDT_BINS, NUM_ATOM_TYPES, RECYCLE_DISTANCE_BINS
83+
export esm2_forward_ad
84+
export prepare_inputs, run_esm2, run_embedding
85+
export run_trunk, run_trunk_single_pass
86+
export run_structure_module, run_heads, run_pipeline
8187

8288
include("alphabet.jl")
8389
include("esm2.jl")
90+
include("esm2_ad.jl")
8491
include("esmfold_embed.jl")
8592
include("openfold_infer_utils.jl")
8693
include("protein.jl")
8794
include("esmfold_full.jl")
95+
include("pipeline.jl")
8896
include("weights.jl")
8997

9098
end

src/constants.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
# Named dimension constants shared across the model definition and the
# modular pipeline API (replacing previously hard-coded literals).

# Output bins of the distogram and pTM heads' categorical distributions.
const DISTOGRAM_BINS = 64

# Bins in the pLDDT head's per-atom categorical distribution.
const LDDT_BINS = 50

# Atom slots per residue (atom37-style representation — TODO confirm at use sites).
const NUM_ATOM_TYPES = 37

# NOTE(review): presumably the distance-bin count for the recycling
# embedding — not used in the visible code; verify at the use site.
const RECYCLE_DISTANCE_BINS = 15

src/esm2_ad.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
"""
    esm2_forward_ad(esm::ESM2, tokens_bt::AbstractArray{Int,2}) → AbstractArray

Differentiable (Zygote-compatible) forward pass through ESM2, returning the
final hidden states.

The regular ESM2 forward mutates buffers in place (`.+=`, `.=`), which Zygote
cannot trace. Here every in-place operation is replaced by an allocating
equivalent, so `Zygote.gradient` can flow through the full 33-layer transformer.

# Arguments
- `esm::ESM2`: the ESM2 language model.
- `tokens_bt::AbstractArray{Int,2}`: 0-indexed token ids in `(B, T)` layout
  (matching `Alphabet` conventions where `padding_idx` is typically 1).

# Returns
- `AbstractArray` of shape `(C, T, B)` — hidden states after `emb_layer_norm_after`.

# Example
```julia
using Zygote
grads = Zygote.gradient(model.embed.esm) do esm
    sum(esm2_forward_ad(esm, tokens_bt))
end
```

# Notes
- Token dropout is skipped (inference path only).
- Per-layer intermediate representations are not collected; only the final
  output is returned.
- The padding mask and the precomputed attention bias are wrapped in
  `@ignore_derivatives`, so the AD system treats them as opaque constants.
"""
function esm2_forward_ad(esm::ESM2, tokens_bt::AbstractArray{Int,2})
    pad_token = esm.alphabet.padding_idx

    # (B, T) → (T, B), then embed; tokens are 0-indexed, Julia arrays are not.
    seq_tb = permutedims(tokens_bt, (2, 1))
    h = esm.embed_scale .* esm.embed_tokens(seq_tb .+ 1)

    # The padding mask is constant w.r.t. the differentiated parameters.
    pad_bt = @ignore_derivatives tokens_bt .== pad_token
    padded = @ignore_derivatives any(pad_bt)

    # Zero the embeddings at padded positions (allocating multiply, no `.=`).
    if padded
        pad_tb = @ignore_derivatives permutedims(pad_bt, (2, 1))
        h = h .* reshape(1 .- pad_tb, 1, size(h, 2), size(h, 3))
    end

    # Additive attention bias: -Inf at padded key positions, expanded to
    # (T, T, heads, B). Constant for AD purposes, hence @ignore_derivatives.
    attn_bias = @ignore_derivatives begin
        if padded
            mask_tb = permutedims(pad_bt, (2, 1))
            bias = ifelse.(mask_tb, Float32(-Inf), 0f0)
            bias4 = reshape(bias, size(bias, 1), 1, 1, size(bias, 2))
            repeat(bias4, 1, size(seq_tb, 1), esm.attention_heads, 1)
        else
            nothing
        end
    end

    attn_key_mask = @ignore_derivatives padded ? pad_bt : nothing

    # Transformer stack with an allocating residual path (no `.+=`).
    for li in 1:esm.num_layers
        blk = esm.layers[li]

        # Self-attention sub-block (pre-LN).
        skip = h
        h = blk.self_attn_layer_norm(h)
        attn_out, _ = blk.self_attn(
            h;
            key_padding_mask = attn_key_mask,
            need_head_weights = false,
            _precomputed_attn_bias = attn_bias,
        )
        h = skip .+ attn_out

        # Feed-forward sub-block (pre-LN, GELU).
        skip = h
        h = blk.final_layer_norm(h)
        h = NNlib.gelu.(blk.fc1(h))
        h = blk.fc2(h)
        h = skip .+ h
    end

    # Final normalization; result is (C, T, B).
    return esm.emb_layer_norm_after(h)
end

src/esmfold_full.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ function ESMFoldModel(
6464
)
6565
trunk = FoldingTrunk(cfg=cfg.trunk)
6666

67-
distogram_bins = 64
67+
distogram_bins = DISTOGRAM_BINS
6868
distogram_head = LinearFirst(c_z, distogram_bins)
6969
ptm_head = LinearFirst(c_z, distogram_bins)
7070
lm_head = LinearFirst(c_s, embed.n_tokens_embed)
71-
lddt_bins = 50
72-
lddt_head = ESMFoldLDDTHead(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim, 37 * lddt_bins)
71+
lddt_bins = LDDT_BINS
72+
lddt_head = ESMFoldLDDTHead(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim, NUM_ATOM_TYPES * lddt_bins)
7373

7474
return ESMFoldModel(
7575
cfg,

0 commit comments

Comments (0)