Fix several issues that prevented the model from running on GPU arrays:

- Use copyto! instead of broadcast in to_device for cross-device transfer
- Widen ESM2Output types from Array to AbstractArray so GPU arrays pass through
- Vectorize EOS token placement in esmfold_embed to avoid scalar indexing
- Pull PTM computation to CPU, since it requires per-sequence slicing
- Compute PAE bin centers on CPU to avoid scalar indexing on GPU arrays
- Reorder output_to_pdb to convert to CPU before atom14_to_atom37
- Update README with installation instructions and GPU usage docs
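The EOS-placement fix above replaces per-sequence scalar writes with a single broadcast. A minimal sketch of that idea, with hypothetical names and shapes (shown on plain Arrays; the same broadcast expression works unchanged on CuArrays, where scalar indexing is disallowed by default):

```julia
# Hypothetical shapes: emb is (dim, seq, batch); eos_pos gives each sequence's
# EOS slot. Instead of looping emb[:, eos_pos[b], b] = eos_vec (scalar indexing
# on GPU), build a broadcastable position mask and blend in one fused kernel.
d, L, B = 4, 6, 3
emb = randn(Float32, d, L, B)           # token embeddings (dim, seq, batch)
eos_vec = randn(Float32, d)             # embedding to write at each EOS slot
eos_pos = [4, 6, 5]                     # per-sequence EOS positions (made up)

positions = reshape(1:L, 1, L, 1)       # 1×L×1, broadcasts over dim and batch
is_eos = positions .== reshape(eos_pos, 1, 1, B)   # 1×L×B Bool mask
emb .= ifelse.(is_eos, eos_vec, emb)    # single broadcast, no scalar indexing
```

The broadcast touches every element once but launches one kernel regardless of batch size, which is the trade the commit makes against a scalar loop.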
Custom GPU kernels and memory optimizations that bring Julia ESMFold from ~4.7s to 1.85s (batch=1, 127 residues), beating the Python/PyTorch reference of 2.69s on NVIDIA GB10 (Blackwell).

Key changes:

- Flash multi-head attention (cuTile): fused QK^T → softmax → V with online softmax and TF32 tensor cores. Includes OOB K-position masking via a ct.arange bounds check on the last tile to prevent softmax corruption. No-bias variant for ESM2, bias variant for ESMFold/triangle attention.
- Fused LayerNorm (cuTile): a single kernel replacing ~9 launches and 6 intermediate allocations. In-place variant for pre-allocated buffers.
- Pre-allocated buffer pool (_perm_buf_pool): a reusable CuArray cache (~15 slots) for permutedims!, matmul outputs, and attention intermediates.
- CUDA.unsafe_free!: explicit GPU memory release across all components, reducing per-block time from 16.5ms to 7.7ms.
- Fused linear projections: concatenated weight matrices for a single matmul in TriangleMultiplicativeUpdate (4 projections → 1).
- cuTENSOR tensor contractions with cached plans for _combine_projections.
- Fused rotary position embeddings (2 broadcasts instead of ~5 kernels).

Parity: max 0.18 Å coordinate difference from the CPU Float32 reference (TF32 accumulation). Deterministic across GPU runs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
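The "online softmax" the flash-attention kernel relies on can be sketched on the CPU. This is an illustrative single-row version, not the cuTile code: scores arrive in tiles, and a running max, running denominator, and rescaled accumulator let softmax(s) * V be computed without ever materializing the full score row.

```julia
# Online softmax over one query row. s holds the row of QK^T scores, V the
# value matrix (keys × head_dim), tile the tile width. On the GPU, scores at
# out-of-bounds K positions in the last tile are masked to -Inf instead of
# being sliced short, which is the OOB masking the commit message mentions.
function online_softmax_row(s::Vector{Float32}, V::Matrix{Float32}, tile::Int)
    m = -Inf32                        # running max of scores seen so far
    l = 0.0f0                         # running softmax denominator
    acc = zeros(Float32, size(V, 2))  # running weighted sum of V rows
    for lo in 1:tile:length(s)
        hi = min(lo + tile - 1, length(s))   # last tile may be short
        st = s[lo:hi]
        Vt = V[lo:hi, :]
        mt = max(m, maximum(st))             # new running max
        p = exp.(st .- mt)
        c = exp(m - mt)                      # rescale old stats to new max
        l = l * c + sum(p)
        acc = acc .* c .+ vec(p' * Vt)
        m = mt
    end
    return acc ./ l                          # == softmax(s)' * V
end
```

The first iteration is safe because exp(-Inf32 - mt) is exactly zero, so the stale statistics contribute nothing.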
Move reusable protein layers to Onion.jl and GPU kernels to OnionTile.jl. ESMFold.jl now contains only model-specific code:

- ESM2 language model, Alphabet, weight loading, inference
- ESMFoldEmbed (depends on ESM2), ESMFoldModel
- PDB output, SafeTensors reader

Removed 13 source files (-3226 lines) and cleaned up unused deps (BatchedTransformations, Einops, SpecialFunctions, cuTile, cuTENSOR). Added InvariantPointAttention -> ESMFoldIPA backward-compat alias.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
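The backward-compat alias mentioned above is a one-liner in Julia; a sketch with stand-in type names (the real ESMFoldIPA fields are not shown in this log):

```julia
# Stand-in for the renamed layer type; fields are illustrative only.
struct ESMFoldIPA
    heads::Int
end

# const alias: old code that constructs or dispatches on the old name keeps
# working, because both names resolve to the same type object.
const InvariantPointAttention = ESMFoldIPA

layer = InvariantPointAttention(8)   # old call site, new type
```

Because the alias is `const`, dispatch, constructors, and `isa` checks all behave identically under either name, which is why it suffices as a compatibility shim after the rename.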
OnionTile is a pure side-effect module (CuArray method overrides). Now that Onion.jl has AnyGPUArray dispatch hooks routing to ONIONop KA kernels, ESMFold can run on any GPU without OnionTile loaded.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
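The "pure side-effect module" pattern can be sketched without a GPU: the base package owns a generic function, and a kernel package merely adds a method for its array type, so loading it changes dispatch but adds no API. All names here are illustrative, and FakeGPUArray stands in for CuArray:

```julia
# Generic fallback the base package ships (plain-array path).
function softmax_cols(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims=1))
    return e ./ sum(e; dims=1)
end

# Minimal wrapper standing in for a GPU array type such as CuArray.
struct FakeGPUArray{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(a::FakeGPUArray) = size(a.data)
Base.getindex(a::FakeGPUArray, i::Int, j::Int) = a.data[i, j]

# The override a side-effect kernel module would install: more specific than
# the AbstractMatrix method, so it wins dispatch for GPU arrays only.
softmax_cols(a::FakeGPUArray) = FakeGPUArray(softmax_cols(a.data))
```

With hooks like this living in the base package and routing to portable KA kernels, a CUDA-specific override module becomes an optional accelerator rather than a load-bearing dependency, which is the point of the commit.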
OnionTile should not be a hard dependency of ESMFold.jl. GPU kernels are loaded optionally via the GPU_test project environment. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
New exports: prepare_inputs, run_esm2, run_embedding, run_trunk, run_trunk_single_pass, run_structure_module, run_heads, run_pipeline.

Also adds esm2_forward_ad (an AD-compatible ESM2 forward) and named constants (DISTOGRAM_BINS, LDDT_BINS, NUM_ATOM_TYPES, RECYCLE_DISTANCE_BINS).

run_pipeline() produces bit-identical output to infer(), verified on GPU with 0.0 max diff across all 25 output keys.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
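The staged entry points above compose into the same computation as the monolithic entry point, which is what makes the bit-identical check possible. A toy sketch of that structure (the real ESMFold.jl signatures are not shown here; the stages below are arithmetic stand-ins for run_esm2, run_trunk, etc.):

```julia
# Stand-in stages: each takes the previous stage's output, like the staged API.
stages = (x -> x .+ 1f0, x -> x .* 2f0, x -> x .- 3f0)

# Pipeline entry point: thread the input through every stage in order.
run_pipeline_sketch(x) = foldl((acc, f) -> f(acc), stages; init = x)

# Monolithic equivalent of infer(): same ops, same order, same float results.
infer_sketch(x) = (x .+ 1f0) .* 2f0 .- 3f0
```

Because both paths execute identical floating-point operations in identical order, equality holds bitwise, not just approximately, mirroring the 0.0-max-diff verification reported in the commit.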