perf: reuse AF2 trajectory runtime across binder hallucination trajectories #359

Open
mahdip72 wants to merge 6 commits into martinpacesa:main from mahdip72:main

@mahdip72

BindCraft currently rebuilds the AF2 design model on every trajectory (clear_mem() + mk_afdesign_model(...)), which adds repeated setup/compile overhead and introduces idle gaps between trajectories.

This PR reuses a single AF2 trajectory model instance across the trajectory loop (same workflow/outputs, lower steady-state runtime).
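The reuse pattern described above can be sketched roughly as follows. This is a minimal illustration, not the PR's TrajectoryDesignRuntime: the class name CachedModelRuntime, the get_model() method, the build_count counter, and the zero-argument build_model factory are all hypothetical, and a plain object() stands in for the expensive AF2 design model.

```python
from typing import Any, Callable, Optional


class CachedModelRuntime:
    """Hypothetical stand-in for a reusable trajectory runtime.

    `build_model` is an assumed zero-argument factory. In a real pipeline it
    would wrap the expensive model construction; here it can be any callable.
    """

    def __init__(self, build_model: Callable[[], Any]) -> None:
        self._build_model = build_model
        self._model: Optional[Any] = None
        self.build_count = 0  # exposed only so the example can show reuse

    def get_model(self) -> Any:
        # Build lazily on first use, or again after invalidate().
        if self._model is None:
            self._model = self._build_model()
            self.build_count += 1
        return self._model

    def invalidate(self) -> None:
        # Drop the cached reference so the next get_model() rebuilds.
        self._model = None


runtime = CachedModelRuntime(build_model=lambda: object())
first = runtime.get_model()
second = runtime.get_model()  # reused, no rebuild
runtime.invalidate()          # e.g. after another step freed device memory
third = runtime.get_model()   # rebuilt on demand
```

With this shape, the per-trajectory cost of construction is paid once per build rather than once per call, and invalidate() is the single place where a rebuild is forced.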

What changed

  • Added reusable trajectory runtime in functions/colabdesign_utils.py (TrajectoryDesignRuntime).
  • bindcraft.py now initializes it once and passes it to binder_hallucination(...).
  • Added explicit runtime invalidation after MPNN generation:
    • mpnn_gen_sequence(...) calls clear_mem()
    • then trajectory_runtime.invalidate() ensures AF2 model is safely rebuilt on the next trajectory.
  • Fixed 4-stage state leak by avoiding in-place mutation of shared advanced_settings iteration values.
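The last bullet concerns a general hazard that becomes visible once objects live across trajectories: mutating a shared settings dict in place makes one trajectory's adjustment carry over into the next. The sketch below shows the pattern only; the key name, the extra parameter, and both function names are illustrative assumptions and are not taken from the PR diff.

```python
def run_leaky(settings: dict, extra: int) -> int:
    # Writes the adjusted value back into the shared dict, so the bump
    # accumulates across calls.
    settings["soft_iterations"] = settings["soft_iterations"] + extra
    return settings["soft_iterations"]


def run_safe(settings: dict, extra: int) -> int:
    # Computes a local value and leaves the shared dict untouched.
    soft_iterations = settings["soft_iterations"] + extra
    return soft_iterations


shared_leaky = {"soft_iterations": 75}
leaky_runs = [run_leaky(shared_leaky, extra=45) for _ in range(3)]

shared_safe = {"soft_iterations": 75}
safe_runs = [run_safe(shared_safe, extra=45) for _ in range(3)]

print(leaky_runs)  # [120, 165, 210]
print(safe_runs)   # [120, 120, 120]
```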

What did NOT change

  • No CLI/config schema changes.
  • No output format / CSV schema changes.
  • No acceptance/filtering logic changes.
  • No MPNN validation-path redesign.

Correctness verification

Baseline vs optimized comparison across:
2stage, 3stage, greedy, mcmc, 4stage (fixed seeds/settings)

Matched behavior on:

  • terminate status
  • designed sequence
  • key metrics (within tiny float tolerance)
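A comparison along these lines could be expressed as below. This is a sketch under assumptions: the record layout (terminate, sequence, metrics), the function name trajectories_match, and the tolerance values are all hypothetical, and the sample records are dummy placeholders rather than results from this PR.

```python
import math


def trajectories_match(baseline: dict, optimized: dict,
                       rel_tol: float = 1e-5, abs_tol: float = 1e-6) -> bool:
    """Return True if two trajectory records agree.

    Terminate status and sequence must be identical; metrics only need to
    agree within the given floating-point tolerances.
    """
    if baseline["terminate"] != optimized["terminate"]:
        return False
    if baseline["sequence"] != optimized["sequence"]:
        return False
    if baseline["metrics"].keys() != optimized["metrics"].keys():
        return False
    return all(
        math.isclose(baseline["metrics"][name], optimized["metrics"][name],
                     rel_tol=rel_tol, abs_tol=abs_tol)
        for name in baseline["metrics"]
    )


# Dummy placeholder records, not measured values.
base = {"terminate": "", "sequence": "MKTAYIAK", "metrics": {"plddt": 0.8500000}}
same = {"terminate": "", "sequence": "MKTAYIAK", "metrics": {"plddt": 0.8500001}}
diff = {"terminate": "", "sequence": "MKTAYIAR", "metrics": {"plddt": 0.8500000}}

print(trajectories_match(base, same))  # True
print(trajectories_match(base, diff))  # False
```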

Performance (A6000 Ada 48GB, 4stage, fixed seeds)

Seed    Baseline (ms)  Optimized (ms)  Speedup
420761  162955.38      158741.67       1.027x
420762  162895.42      27245.98        5.979x
420763  168697.97      27216.95        6.198x
420764  168267.46      25822.96        6.516x

Overall:

  • first run (cold compile): similar
  • steady-state runs: ~6.2x faster
  • total 4-run batch: ~2.77x faster
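The summary figures can be re-derived from the per-seed timings. The snippet below assumes that "steady-state" means the three runs after the first and that it is averaged as a mean of per-seed speedups; the description does not say which averaging was used, though the ratio of summed times gives roughly the same value.

```python
# Per-seed timings in milliseconds: seed -> (baseline, optimized).
timings_ms = {
    420761: (162955.38, 158741.67),
    420762: (162895.42, 27245.98),
    420763: (168697.97, 27216.95),
    420764: (168267.46, 25822.96),
}

# Per-seed speedup is baseline time divided by optimized time.
speedups = {seed: base / opt for seed, (base, opt) in timings_ms.items()}

# Treat every run after the first (cold compile) as steady state.
steady_seeds = list(timings_ms)[1:]
steady_mean = sum(speedups[s] for s in steady_seeds) / len(steady_seeds)

# Whole-batch speedup compares total wall time across all four runs.
total_baseline = sum(base for base, _ in timings_ms.values())
total_optimized = sum(opt for _, opt in timings_ms.values())
batch_speedup = total_baseline / total_optimized

for seed, value in speedups.items():
    print(f"{seed}: {value:.3f}x")
print(f"steady-state mean: {steady_mean:.1f}x")
print(f"total batch: {batch_speedup:.2f}x")
```

The batch figure is much lower than the steady-state figure because the first run, which still pays the compile cost, accounts for most of the optimized total.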

Note

I closed the previous PR and opened this fresh PR to keep review scope clean.

@martinpacesa

mahdip72 and others added 6 commits February 27, 2026 15:57
- Add invalidate() method to reset cached model after clear_mem()
- Zero out loss callback weights when settings are disabled
- Document clear_mem() interaction in _ensure_model() docstring
- Call invalidate() after mpnn_gen_sequence in bindcraft.py
- Update notebook to create/pass TrajectoryDesignRuntime and invalidate

Co-authored-by: mahdip72 <42680708+mahdip72@users.noreply.github.com>
fix: harden TrajectoryDesignRuntime against clear_mem() invalidation and stale loss callbacks
…repo

perf: reuse AF2 trajectory runtime across binder hallucination trajectories
Copilot AI review requested due to automatic review settings February 27, 2026 23:18
Copilot AI left a comment

Pull request overview

This PR reduces per-trajectory overhead in BindCraft by introducing a reusable AlphaFold2 (ColabDesign) “trajectory runtime” so binder hallucination can reuse a compiled AF2 design model across the trajectory loop.

Changes:

  • Added TrajectoryDesignRuntime to cache/rebuild the AF2 design model on demand and integrated it into binder_hallucination(...).
  • Updated bindcraft.py and the Colab notebook to initialize and pass the reusable runtime, and to invalidate it after MPNN generation.
  • Added a .gitignore for Python bytecode artifacts.

Reviewed changes

Copilot reviewed 3 out of 4 changed files in this pull request and generated 4 comments.

File                            Description
functions/colabdesign_utils.py  Introduces TrajectoryDesignRuntime and updates binder_hallucination to optionally reuse a cached AF2 model.
bindcraft.py                    Creates a shared trajectory runtime once and passes it through binder hallucination; invalidates after MPNN generation.
notebooks/BindCraft.ipynb       Mirrors the runtime reuse/invalidation changes in the Colab workflow.
.gitignore                      Ignores __pycache__/ and *.pyc.

Comment on lines +183 to +186
    if runtime is None:
        if _DEFAULT_TRAJECTORY_RUNTIME is None:
            _DEFAULT_TRAJECTORY_RUNTIME = TrajectoryDesignRuntime(advanced_settings)
        runtime = _DEFAULT_TRAJECTORY_RUNTIME

Copilot AI Feb 27, 2026

The module-level _DEFAULT_TRAJECTORY_RUNTIME caches a single TrajectoryDesignRuntime instance across calls. If binder_hallucination() is invoked later in the same Python process with a different advanced_settings (e.g., running multiple campaigns in one session), the cached runtime will silently reuse the old settings/model, which can produce incorrect behavior and also keep GPU memory alive longer than intended. Prefer creating a new runtime per call when runtime is not provided, or key the cache by a stable identifier (e.g., settings path/hash) and provide an explicit reset/close API.

Comment on lines 178 to +179
mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
trajectory_runtime.invalidate()

Copilot AI Feb 27, 2026

mpnn_gen_sequence() calls clear_mem() internally, which can free the JAX buffers backing trajectory_runtime.af_model while trajectory_runtime still holds a reference to it. To avoid a transient invalid cached state (and make failure modes safer), consider calling trajectory_runtime.invalidate() immediately before invoking mpnn_gen_sequence() (or wrap the call in a try/finally that invalidates even if MPNN generation errors).

Suggested change
- mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
- trajectory_runtime.invalidate()
+ trajectory_runtime.invalidate()
+ try:
+     mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
+ finally:
+     trajectory_runtime.invalidate()

Comment on lines +458 to +459
" mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n",
" trajectory_runtime.invalidate()\n",

Copilot AI Feb 27, 2026

This notebook invalidates the cached trajectory runtime after mpnn_gen_sequence(), but mpnn_gen_sequence() starts by calling clear_mem(), which may free the JAX buffers of the cached AF2 model while trajectory_runtime still references it. To avoid a transient invalid cached state (and make errors easier to recover from in an interactive session), invalidate the runtime before calling mpnn_gen_sequence() or use a try/finally around the call.

Suggested change
- " mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n",
- " trajectory_runtime.invalidate()\n",
+ " trajectory_runtime.invalidate()\n",
+ " mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n",

" os.system(\"(mkdir bindcraft/params; apt-get install aria2 -qq; \\\n",
" aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \\\n",
" tar -xf alphafold_params_2022-12-06.tar -C bindcraft/params; touch bindcraft/params/done.txt )&\")\n",
" os.system(\"pip install git+https://github.com/sokrypton/ColabDesign.git\")\n",

Copilot AI Feb 27, 2026

The installation cell installs ColabDesign directly from the mutable GitHub HEAD via pip install git+https://github.com/sokrypton/ColabDesign.git, which creates a supply-chain risk if that repository is compromised or a malicious commit is pushed. In a Colab environment with an attached Google Drive, a compromised dependency could exfiltrate data or run arbitrary code with the notebook’s privileges. To mitigate this, pin the dependency to a specific commit or tagged release and, where possible, verify its integrity (e.g., via hashes or a vetted package index) instead of installing from an unpinned VCS URL.
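As a rough illustration of pinning, pip accepts a revision after an "@" in a VCS URL, so the install string can be built from a fixed commit. The helper name pinned_git_requirement is hypothetical, and the all-zero hash below is a placeholder that must be replaced with a reviewed commit; it is not a real ColabDesign revision.

```python
import re


def pinned_git_requirement(repo_url: str, commit: str) -> str:
    """Build a pip VCS requirement pinned to a full 40-character commit hash."""
    # Reject branch names and abbreviated hashes, which can move or collide.
    if not re.fullmatch(r"[0-9a-f]{40}", commit):
        raise ValueError("expected a full 40-character lowercase hex commit hash")
    return f"git+{repo_url}@{commit}"


# Placeholder hash for illustration only; substitute a vetted commit.
PLACEHOLDER_COMMIT = "0" * 40

requirement = pinned_git_requirement(
    "https://github.com/sokrypton/ColabDesign.git", PLACEHOLDER_COMMIT
)
print(f"pip install {requirement}")
```

Requiring the full hash keeps the install reproducible, since a branch or tag name can be re-pointed after review.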
