Fix memory leak in DefaultPredictionStrategy cache hooks (#2631) #2734

Open

saitcakmak wants to merge 2 commits into cornellius-gp:main from
Conversation
…gp#2631)

Root cause
----------

When `detach_test_caches` is False (as used by BoTorch via `propagate_grads(True)`), the `_mean_cache` and `_exact_predictive_covar_inv_quad_form_cache` properties register a `clear_cache_hook` on the cached tensor's `grad_fn`:

```python
wrapper = functools.partial(clear_cache_hook, self)
mean_cache.grad_fn.register_hook(wrapper)
```

This creates a reference cycle:

```
prediction_strategy -> _memoize_cache -> cached tensor
    -> grad_fn (C++ object) -> hook closure -> prediction_strategy
```

Because PyTorch's `grad_fn` is a C++ object, Python's cycle garbage collector cannot traverse the C++/Python boundary to detect the cycle. The entire chain becomes uncollectable. For fantasy models, the cached tensor's computation graph holds references back to the parent model's caches, so each iteration adds another uncollectable chain, causing memory to grow indefinitely until OOM.

Fix
---

Added `register_cache_clear_hook(tsr, module)` to `gpytorch/utils/memoize.py`. This helper uses `weakref.ref(module)` when registering the backward hook on `tsr.grad_fn`, breaking the reference cycle. When no external strong references to the prediction strategy remain, it and its caches are garbage-collected normally, and the hook becomes a no-op.
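A minimal sketch of the weakref pattern the helper uses (the actual implementation lives in `gpytorch/utils/memoize.py` and may differ in detail; `clear_cache_hook` here is a simplified stand-in for gpytorch's cache-clearing logic):

```python
import weakref


def clear_cache_hook(module, *args):
    # Simplified stand-in for gpytorch's cache-clearing logic.
    module._memoize_cache = {}


def register_cache_clear_hook(tsr, module):
    """Register a backward hook on ``tsr.grad_fn`` that clears ``module``'s
    memoize cache, holding the module only through a weak reference so the
    hook closure cannot keep the prediction strategy (and its caches) alive.
    """
    module_ref = weakref.ref(module)

    def _hook(*args):
        module = module_ref()  # resolves to None once the module is collected
        if module is not None:
            clear_cache_hook(module)

    if tsr.grad_fn is not None:
        tsr.grad_fn.register_hook(_hook)
```

Because the closure captures only the weak reference, the `grad_fn -> hook closure -> prediction_strategy` leg of the cycle is severed: once the last external strong reference is gone, the module is collected and the hook simply does nothing.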
Updated all three call sites to use the new helper:

- `DefaultPredictionStrategy._exact_predictive_covar_inv_quad_form_cache`
- `DefaultPredictionStrategy._mean_cache`
- `_variational_strategy._add_cache_hook`

Validation
----------

Repro script (200 iterations of fantasy model creation + evaluation with `detach_test_caches(False)`):

```python
import torch, tracemalloc

from gpytorch import settings as gpt_settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

d = 10


class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


gp = SimpleGP(
    train_inputs=torch.rand(256, d, dtype=torch.double),
    train_targets=torch.rand(256, dtype=torch.double),
).eval()
gp(torch.rand(5, d, dtype=torch.double))

X = torch.rand(128, 5, d, dtype=torch.double)
Y = torch.rand(128, 5, dtype=torch.double)

tracemalloc.start()
for i in range(200):
    fantasy_model = gp.get_fantasy_model(inputs=X, targets=Y).eval()
    with gpt_settings.detach_test_caches(False):
        fantasy_model(torch.rand(32, d, dtype=torch.double))
    if (i + 1) % 50 == 0:
        current, peak = tracemalloc.get_traced_memory()
        print(
            f"Iter {i+1}: current={current/1024/1024:.1f}MB, "
            f"peak={peak/1024/1024:.1f}MB"
        )
```

Results before the fix: memory grows by ~50 MB or more per 50 iterations, reaching OOM by around iteration 200. Results after the fix: memory stays flat at ~1-2 MB across all 200 iterations.

Gradient propagation through fantasy model predictions with `detach_test_caches(False)` was also verified to still work correctly. All 59 existing tests in test_exact_gp.py and test_derivative_gp_fantasy.py pass.
Fixes #2631 and meta-pytorch/botorch#2728
Root cause

When `detach_test_caches` is False (as used by BoTorch via `propagate_grads(True)`), the `_mean_cache` and `_exact_predictive_covar_inv_quad_form_cache` properties register a `clear_cache_hook` on the cached tensor's `grad_fn`. This creates a reference cycle:

prediction_strategy -> _memoize_cache -> cached tensor -> grad_fn (C++ object) -> hook closure -> prediction_strategy

Because PyTorch's `grad_fn` is a C++ object, Python's cycle garbage collector cannot traverse the C++/Python boundary to detect the cycle. The entire chain becomes uncollectable. For fantasy models, the cached tensor's computation graph holds references back to the parent model's caches, so each iteration adds another uncollectable chain, causing memory to grow indefinitely until OOM.

Fix
Added `register_cache_clear_hook(tsr, module)` to `gpytorch/utils/memoize.py`. This helper uses `weakref.ref(module)` when registering the backward hook on `tsr.grad_fn`, breaking the reference cycle. When no external strong references to the prediction strategy remain, it and its caches are garbage-collected normally, and the hook becomes a no-op.

Updated all three call sites to use the new helper:

- `DefaultPredictionStrategy._exact_predictive_covar_inv_quad_form_cache`
- `DefaultPredictionStrategy._mean_cache`
- `_variational_strategy._add_cache_hook`

Validation (ran manually)
Repro script (200 iterations of fantasy model creation + evaluation with `detach_test_caches(False)`):

Before: OOM with 48GB RAM.
```
(botorch) saitcakmak@saitcakmak-mac gpytorch % git branch
  fix-memory-leak-mean-cache
* main
(botorch) saitcakmak@saitcakmak-mac gpytorch % python test.py
Iter 100: current=3.6MB, peak=3.6MB
Iter 200: current=6.7MB, peak=6.8MB
zsh: killed     python test.py
```

After:
```
(botorch) saitcakmak@saitcakmak-mac gpytorch % git branch
* fix-memory-leak-mean-cache
  main
(botorch) saitcakmak@saitcakmak-mac gpytorch % python test.py
Iter 100: current=1.1MB, peak=1.8MB
Iter 200: current=1.4MB, peak=2.0MB
Iter 300: current=1.6MB, peak=2.3MB
Iter 400: current=0.4MB, peak=2.6MB
Iter 500: current=0.9MB, peak=2.6MB
Iter 600: current=1.3MB, peak=2.6MB
Iter 700: current=1.5MB, peak=2.6MB
Iter 800: current=2.6MB, peak=2.6MB
Iter 900: current=2.8MB, peak=2.8MB
Iter 1000: current=3.0MB, peak=3.1MB
```
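A lightweight way to guard against regressions of this kind (a generic sketch, not taken from the PR's test suite) is to assert that an object, once dereferenced, is actually reclaimed:

```python
import gc
import weakref


def assert_collected(factory):
    """Create an object, drop the only strong reference, and assert that the
    garbage collector reclaims it. If a hook or cache pins the object, the
    weak reference stays live and the assertion fails."""
    obj = factory()
    probe = weakref.ref(obj)
    del obj
    gc.collect()
    assert probe() is None, "object survived collection: something still holds it"


class Dummy:
    """Placeholder; a real test would build and evaluate a fantasy model here."""


assert_collected(Dummy)
```

Applied to this bug, the factory would construct a fantasy model, evaluate it under `detach_test_caches(False)`, and return it; before the fix the weak reference would stay live and the assertion would fail.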