perf: reuse AF2 trajectory runtime across binder hallucination trajectories#359
mahdip72 wants to merge 6 commits into martinpacesa:main
- Add invalidate() method to reset cached model after clear_mem()
- Zero out loss callback weights when settings are disabled
- Document clear_mem() interaction in _ensure_model() docstring
- Call invalidate() after mpnn_gen_sequence in bindcraft.py
- Update notebook to create/pass TrajectoryDesignRuntime and invalidate

Co-authored-by: mahdip72 <42680708+mahdip72@users.noreply.github.com>
fix: harden TrajectoryDesignRuntime against clear_mem() invalidation and stale loss callbacks
…repo perf: reuse AF2 trajectory runtime across binder hallucination trajectories
Pull request overview
This PR reduces per-trajectory overhead in BindCraft by introducing a reusable AlphaFold2 (ColabDesign) “trajectory runtime” so binder hallucination can reuse a compiled AF2 design model across the trajectory loop.
Changes:
- Added `TrajectoryDesignRuntime` to cache/rebuild the AF2 design model on demand and integrated it into `binder_hallucination(...)`.
- Updated `bindcraft.py` and the Colab notebook to initialize and pass the reusable runtime, and to invalidate it after MPNN generation.
- Added a `.gitignore` for Python bytecode artifacts.
Reviewed changes
Copilot reviewed 3 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `functions/colabdesign_utils.py` | Introduces `TrajectoryDesignRuntime` and updates `binder_hallucination` to optionally reuse a cached AF2 model. |
| `bindcraft.py` | Creates a shared trajectory runtime once and passes it through binder hallucination; invalidates after MPNN generation. |
| `notebooks/BindCraft.ipynb` | Mirrors the runtime reuse/invalidation changes in the Colab workflow. |
| `.gitignore` | Ignores `__pycache__/` and `*.pyc`. |
    if runtime is None:
        if _DEFAULT_TRAJECTORY_RUNTIME is None:
            _DEFAULT_TRAJECTORY_RUNTIME = TrajectoryDesignRuntime(advanced_settings)
        runtime = _DEFAULT_TRAJECTORY_RUNTIME
The module-level _DEFAULT_TRAJECTORY_RUNTIME caches a single TrajectoryDesignRuntime instance across calls. If binder_hallucination() is invoked later in the same Python process with a different advanced_settings (e.g., running multiple campaigns in one session), the cached runtime will silently reuse the old settings/model, which can produce incorrect behavior and also keep GPU memory alive longer than intended. Prefer creating a new runtime per call when runtime is not provided, or key the cache by a stable identifier (e.g., settings path/hash) and provide an explicit reset/close API.
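One way the "key the cache by a stable identifier" option could look is sketched below. This is only an illustration under assumptions: `settings_key`, `get_runtime`, `reset_runtimes`, and the `factory` argument are hypothetical names, not part of this PR or of ColabDesign, and the settings are assumed to be a JSON-serializable dict.

```python
import hashlib
import json
from typing import Any, Callable, Dict

# Hypothetical module-level cache: settings hash -> runtime object.
_RUNTIME_CACHE: Dict[str, Any] = {}


def settings_key(advanced_settings: dict) -> str:
    """Build a stable hash of the settings (assumes JSON-friendly values)."""
    payload = json.dumps(advanced_settings, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def get_runtime(advanced_settings: dict, factory: Callable[[dict], Any]) -> Any:
    """Return the runtime cached for these settings, creating it if needed."""
    key = settings_key(advanced_settings)
    if key not in _RUNTIME_CACHE:
        _RUNTIME_CACHE[key] = factory(advanced_settings)
    return _RUNTIME_CACHE[key]


def reset_runtimes() -> None:
    """Explicit reset API: drop every cached runtime."""
    _RUNTIME_CACHE.clear()


# Usage with a dummy factory; a real one would construct the runtime class.
make = lambda settings: {"settings": dict(settings)}
a = get_runtime({"k": 1}, make)
b = get_runtime({"k": 1}, make)   # same settings, same cached runtime
c = get_runtime({"k": 2}, make)   # different settings, separate runtime
```

Note that a dict cache like this keeps one runtime alive per distinct settings value, so if GPU memory is the concern, a single-slot cache that replaces its entry when the key changes may be the better fit.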
    mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
    trajectory_runtime.invalidate()
mpnn_gen_sequence() calls clear_mem() internally, which can free the JAX buffers backing trajectory_runtime.af_model while trajectory_runtime still holds a reference to it. To avoid a transient invalid cached state (and make failure modes safer), consider calling trajectory_runtime.invalidate() immediately before invoking mpnn_gen_sequence() (or wrap the call in a try/finally that invalidates even if MPNN generation errors).
    - mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
    - trajectory_runtime.invalidate()
    + trajectory_runtime.invalidate()
    + try:
    +     mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
    + finally:
    +     trajectory_runtime.invalidate()
    " mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n",
    " trajectory_runtime.invalidate()\n",
This notebook invalidates the cached trajectory runtime after mpnn_gen_sequence(), but mpnn_gen_sequence() starts by calling clear_mem(), which may free the JAX buffers of the cached AF2 model while trajectory_runtime still references it. To avoid a transient invalid cached state (and make errors easier to recover from in an interactive session), invalidate the runtime before calling mpnn_gen_sequence() or use a try/finally around the call.
    - " mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n",
    - " trajectory_runtime.invalidate()\n",
    + " trajectory_runtime.invalidate()\n",
    + " mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n",
    " os.system(\"(mkdir bindcraft/params; apt-get install aria2 -qq; \\\n",
    " aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \\\n",
    " tar -xf alphafold_params_2022-12-06.tar -C bindcraft/params; touch bindcraft/params/done.txt )&\")\n",
    " os.system(\"pip install git+https://github.com/sokrypton/ColabDesign.git\")\n",
The installation cell installs ColabDesign directly from the mutable GitHub HEAD via pip install git+https://github.com/sokrypton/ColabDesign.git, which creates a supply-chain risk if that repository is compromised or a malicious commit is pushed. In a Colab environment with an attached Google Drive, a compromised dependency could exfiltrate data or run arbitrary code with the notebook’s privileges. To mitigate this, pin the dependency to a specific commit or tagged release and, where possible, verify its integrity (e.g., via hashes or a vetted package index) instead of installing from an unpinned VCS URL.
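A pinned install could be built along the following lines. This is a minimal sketch: `pinned_git_requirement` is a hypothetical helper, and the all-zero SHA is a placeholder rather than a real ColabDesign commit; a maintainer would substitute a vetted commit. The `git+<url>@<rev>` form itself is pip's standard VCS requirement syntax.

```python
import re


def pinned_git_requirement(repo_url: str, commit_sha: str) -> str:
    """Return a pip VCS requirement pinned to a full commit SHA.

    Branch names and short hashes are rejected because they can move
    or be ambiguous; only a full 40-character SHA is accepted.
    """
    if not re.fullmatch(r"[0-9a-f]{40}", commit_sha):
        raise ValueError("expected a full 40-character lowercase hex commit SHA")
    return f"git+{repo_url}@{commit_sha}"


PLACEHOLDER_SHA = "0" * 40  # placeholder only, NOT a real commit
requirement = pinned_git_requirement(
    "https://github.com/sokrypton/ColabDesign.git", PLACEHOLDER_SHA
)
# The notebook cell would then run: pip install "<requirement>"
```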
BindCraft currently rebuilds the AF2 design model on every trajectory (`clear_mem()` + `mk_afdesign_model(...)`), which adds repeated setup/compile overhead and introduces idle gaps between trajectories.

This PR reuses a single AF2 trajectory model instance across the trajectory loop (same workflow/outputs, lower steady-state runtime).
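The reuse idea can be sketched as a lazily built, invalidatable cache. The code below is an illustration, not the PR's actual `TrajectoryDesignRuntime`: `ReusableRuntime`, `get_model`, and the `build_model` callable are hypothetical names, and `object()` stands in for an AF2 design model so the example runs without ColabDesign.

```python
from typing import Any, Callable, Optional


class ReusableRuntime:
    """Hypothetical stand-in for a trajectory runtime (not the PR's real code)."""

    def __init__(self, build_model: Callable[[], Any]) -> None:
        self._build_model = build_model   # expensive factory, assumed to wrap model setup
        self._model: Optional[Any] = None
        self.builds = 0                   # how many times the factory actually ran

    def get_model(self) -> Any:
        """Return the cached model, building it only when none is cached."""
        if self._model is None:
            self._model = self._build_model()
            self.builds += 1
        return self._model

    def invalidate(self) -> None:
        """Drop the cached model so the next get_model() rebuilds it."""
        self._model = None


runtime = ReusableRuntime(build_model=lambda: object())
first = runtime.get_model()
second = runtime.get_model()   # reused across "trajectories", no rebuild
runtime.invalidate()           # e.g. after a call that cleared device memory
third = runtime.get_model()    # rebuilt on demand
```

In this sketch the expensive build runs twice for three requests; without invalidation it would run once for any number of trajectories.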
What changed
- `functions/colabdesign_utils.py` (`TrajectoryDesignRuntime`).
- `bindcraft.py` now initializes it once and passes it to `binder_hallucination(...)`.
- `mpnn_gen_sequence(...)` calls `clear_mem()`; `trajectory_runtime.invalidate()` ensures the AF2 model is safely rebuilt on the next trajectory.
- `advanced_settings` iteration values.

What did NOT change
Correctness verification
Baseline vs optimized comparison across:
- `2stage`, `3stage`, `greedy`, `mcmc`, `4stage` (fixed seeds/settings)

Matched behavior on:
Performance (A6000 Ada 48GB, 4stage, fixed seeds)
Overall:
Note
I closed the previous PR and opened this fresh PR to keep review scope clean.
@martinpacesa