Remove inference multiplicity #36
Open
Closes #24
Description
Running inference with `--nsamples_per_protein` greater than 1 blows up memory when using the torch backend. This happens because `inference_multiplicity` was used to replicate every tensor (including the ESM embeddings) `nsample_per_protein` times before moving them to the GPU, so multi-sample inference pushed huge duplicated feature tensors into vRAM and regularly OOM'd. After this fix, taking 5 samples of a protein sequence of ~1000 amino acids stays at around 31 GiB of vRAM with the torch backend (instead of running out of memory on an 81 GiB GPU).
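As a rough illustration (not the repository's code; the tensor shapes and the `expand`-based replication are assumptions), materializing one copy of the features per sample multiplies their vRAM footprint by the sample count:

```python
import torch

nsamples_per_protein = 5
seq_len, embed_dim = 1000, 2560                  # hypothetical ESM embedding shape
esm_embeddings = torch.randn(1, seq_len, embed_dim)

# inference_multiplicity-style replication: every sample gets its own copy,
# all resident at once before anything is moved to the GPU
replicated = esm_embeddings.expand(nsamples_per_protein, -1, -1).contiguous()
print(f"{replicated.nelement() * replicated.element_size() / 2**20:.0f} MiB")
# ~49 MiB for 5 copies of this one tensor vs ~10 MiB for a single copy;
# repeated across all feature tensors this adds up quickly
```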
Implementation
I fixed it by removing `inference_multiplicity` altogether and simply running inference in an `nsample_per_protein` loop, as in the sketch below. This does slow down inference since it's no longer batched, so feel free to close the PR; it's just here for future reference.
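For reference, a minimal sketch of the loop-based approach (the function and argument names here are assumptions, not the repository's actual API):

```python
import torch

def sample_per_protein(model, features, nsample_per_protein, device="cuda"):
    """Hypothetical sketch of the fix: keep a single copy of the features
    on the GPU and call the model once per requested sample, instead of
    replicating the features nsample_per_protein times up front."""
    features = {k: v.to(device) for k, v in features.items()}  # one copy in vRAM
    samples = []
    with torch.no_grad():
        for _ in range(nsample_per_protein):
            samples.append(model(**features))  # one sample per forward pass
    return samples
```

The trade-off is exactly the one noted above: memory stays flat at the cost of running `nsample_per_protein` sequential forward passes instead of one batched pass.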