
Optimized #5

Merged
murrellb merged 9 commits into main from
optimized
Feb 12, 2026

Conversation

@murrellb
Member

No description provided.

murrellb and others added 9 commits February 6, 2026 23:45
Fix several issues that prevented the model from running on GPU arrays:
- Use copyto! instead of broadcast in to_device for cross-device transfer
- Widen ESM2Output types from Array to AbstractArray so GPU arrays pass through
- Vectorize EOS token placement in esmfold_embed to avoid scalar indexing
- Pull PTM computation to CPU since it requires per-sequence slicing
- Compute PAE bin centers on CPU to avoid scalar indexing on GPU arrays
- Reorder output_to_pdb to convert to CPU before atom14_to_atom37
- Update README with installation instructions and GPU usage docs
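The first two fixes in the list above follow a common CuArray pattern; a minimal sketch in plain Julia (runs on CPU arrays too) — `to_device!` and `place_eos!` are illustrative names, not the package's actual API:

```julia
# Hedged sketch: copyto! performs the cross-device transfer in one
# call, where a broadcast would fall back to scalar indexing on GPU
# arrays.
to_device!(dest::AbstractArray, src::AbstractArray) = copyto!(dest, src)

# Vectorized EOS-token placement: a single indexed broadcast writes the
# EOS embedding at every sequence's EOS position, replacing a scalar
# loop that GPU arrays disallow.
function place_eos!(emb::AbstractMatrix, eos_vec::AbstractVector,
                    eos_positions::AbstractVector{<:Integer})
    emb[:, eos_positions] .= eos_vec
    return emb
end
```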
Custom GPU kernels and memory optimizations bring Julia ESMFold
from ~4.7 s to 1.85 s (batch = 1, 127 residues), beating the Python/PyTorch
reference of 2.69 s on an NVIDIA GB10 (Blackwell).

Key changes:
- Flash multi-head attention (cuTile): fused QK^T → softmax → V with
  online softmax and TF32 tensor cores. Includes OOB K-position masking
  via ct.arange bounds check on the last tile to prevent softmax corruption.
  No-bias variant for ESM2, bias variant for ESMFold/Triangle attention.
- Fused LayerNorm (cuTile): single kernel replacing ~9 launches + 6
  intermediate allocations. In-place variant for pre-allocated buffers.
- Pre-allocated buffer pool (_perm_buf_pool): reusable CuArray cache for
  permutedims!, matmul outputs, and attention intermediates (~15 slots).
- CUDA.unsafe_free!: explicit GPU memory release across all components,
  reducing per-block time from 16.5ms to 7.7ms.
- Fused linear projections: concatenated weight matrices for single matmul
  in TriangleMultiplicativeUpdate (4 projections → 1).
- cuTENSOR tensor contractions with cached plans for _combine_projections.
- Fused rotary position embeddings (2 broadcasts instead of ~5 kernels).
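The online-softmax trick at the heart of the flash attention kernel listed above can be sketched in plain Julia (a CPU stand-in for the cuTile kernel; tile shapes and names are illustrative):

```julia
# Online softmax over K/V tiles: keep a running max `m`, running
# normalizer `l`, and unnormalized accumulator `acc`; whenever a new
# tile raises the max, rescale the old state. The full score row is
# never materialized — the key idea behind flash attention.
function online_softmax_attention(score_tiles, v_tiles)
    m, l = -Inf, 0.0
    acc = zeros(size(first(v_tiles), 2))     # d-dim output accumulator
    for (s, v) in zip(score_tiles, v_tiles)  # s: t-vector, v: t×d tile
        m_new = max(m, maximum(s))
        c = exp(m - m_new)                   # correction for old state
        w = exp.(s .- m_new)                 # stable tile weights
        l = l * c + sum(w)
        acc = acc .* c .+ v' * w
        m = m_new
    end
    return acc ./ l
end
```

The same rescaling is what the OOB masking protects: an out-of-bounds score on the last tile would corrupt `m` and `l` for the whole row.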

Parity: max 0.18 Å coordinate difference from the CPU Float32 reference
(TF32 accumulation). Deterministic across GPU runs.
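The fused-projection change listed above (4 projections → 1 matmul in TriangleMultiplicativeUpdate) amounts to stacking weight matrices once at construction time; a minimal sketch with illustrative names and shapes:

```julia
# Hedged sketch: four projection weights of equal output dim are
# concatenated row-wise once, so each forward pass does one GEMM
# instead of four, and views split the result back with no copy.
struct FusedProjections{T}
    W::Matrix{T}   # vcat of the four weights: (4d) × d_in
    d::Int         # per-projection output dimension
end
FusedProjections(Ws::AbstractMatrix...) =
    FusedProjections(Matrix(vcat(Ws...)), size(first(Ws), 1))

function (f::FusedProjections)(x::AbstractMatrix)
    y = f.W * x                                   # single fused GEMM
    return ntuple(i -> view(y, (i - 1) * f.d + 1 : i * f.d, :), 4)
end
```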

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move reusable protein layers to Onion.jl and GPU kernels to OnionTile.jl.
ESMFold.jl now contains only model-specific code:
- ESM2 language model, Alphabet, weight loading, inference
- ESMFoldEmbed (depends on ESM2), ESMFoldModel
- PDB output, SafeTensors reader

Removed 13 source files (-3226 lines), cleaned up unused deps
(BatchedTransformations, Einops, SpecialFunctions, cuTile, cuTENSOR).
Added an InvariantPointAttention -> ESMFoldIPA backward-compatibility alias.
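A backward-compatibility alias of the kind mentioned above is a one-line const binding; a hedged sketch (the module and field here are illustrative, not the actual Onion.jl definitions):

```julia
# Hedged sketch: after a rename, a const binding keeps the old name
# working — both names refer to the same type, so existing code and
# dispatch are unaffected.
module Layers
    export ESMFoldIPA
    struct ESMFoldIPA
        heads::Int
    end
end

const InvariantPointAttention = Layers.ESMFoldIPA  # old name still works
```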

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
OnionTile is a pure side-effect module (CuArray method overrides).
Now that Onion.jl has AnyGPUArray dispatch hooks routing to ONIONop
KernelAbstractions (KA) kernels, ESMFold can run on any GPU without
OnionTile loaded.
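The dispatch-hook pattern described above can be sketched as a portable fallback plus a more specific GPU method (names are illustrative; `layernorm` here is not the actual Onion.jl API):

```julia
using Statistics

# Portable fallback: runs on any AbstractArray, CPU or GPU, using only
# broadcast and reductions (no scalar indexing).
layernorm(x::AbstractArray) = (x .- mean(x; dims=1)) ./ std(x; dims=1)

# A GPU extension module would add a more specific method, e.g.
#   layernorm(x::AnyGPUArray) = fused_layernorm_kernel(x)
# which multiple dispatch selects automatically whenever a GPU array is
# passed — call sites never change.
```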

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
OnionTile should not be a hard dependency of ESMFold.jl. GPU kernels are
loaded optionally via the GPU_test project environment.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
New exports: prepare_inputs, run_esm2, run_embedding, run_trunk,
run_trunk_single_pass, run_structure_module, run_heads, run_pipeline.
Also adds esm2_forward_ad (AD-compatible ESM2 forward) and named
constants (DISTOGRAM_BINS, LDDT_BINS, NUM_ATOM_TYPES, RECYCLE_DISTANCE_BINS).

run_pipeline() produces bit-identical output to infer() — verified on GPU
with 0.0 max diff across all 25 output keys.
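The bit-identical check reported above can be expressed as a max-abs-diff sweep over the two output dictionaries; a hedged sketch in which plain Dicts stand in for the real infer()/run_pipeline() outputs:

```julia
# Compare every output key of two result dictionaries; parity means a
# max absolute difference of exactly 0.0 for each key.
max_diff(a::AbstractArray, b::AbstractArray) =
    maximum(abs.(Float64.(a) .- Float64.(b)))

function assert_identical_outputs(out_a::AbstractDict, out_b::AbstractDict)
    @assert keys(out_a) == keys(out_b)
    for k in keys(out_a)
        @assert max_diff(out_a[k], out_b[k]) == 0.0 "mismatch in key $k"
    end
    return true
end
```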

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@murrellb murrellb merged commit 52fed45 into main Feb 12, 2026
0 of 4 checks passed
