
Fix memory leak in DefaultPredictionStrategy cache hooks (#2631) #2734

Open
saitcakmak wants to merge 2 commits into cornellius-gp:main from saitcakmak:fix-memory-leak-mean-cache

Conversation

@saitcakmak (Collaborator)

Fixes #2631 and meta-pytorch/botorch#2728

Root cause

When detach_test_caches is False (as used by BoTorch via propagate_grads(True)), the _mean_cache and
_exact_predictive_covar_inv_quad_form_cache properties register a clear_cache_hook on the cached tensor's grad_fn:

wrapper = functools.partial(clear_cache_hook, self)
mean_cache.grad_fn.register_hook(wrapper)

This creates a reference cycle:

prediction_strategy -> _memoize_cache -> cached tensor
    -> grad_fn (C++ object) -> hook closure -> prediction_strategy

Because PyTorch's grad_fn is a C++ object, Python's cycle garbage collector cannot traverse the C++/Python boundary to detect the cycle. The entire chain becomes uncollectable. For fantasy models, the cached tensor's computation graph holds references back to the parent model's caches, so each iteration adds another uncollectable chain, causing memory to grow indefinitely until OOM.
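For intuition, here is a pure-Python sketch of the same reference shape, using a stand-in `Strategy` class rather than the real one. Note that this all-Python version *is* collectable, because the garbage collector can traverse every link; the leak arises only because the middle link in the real cycle is a C++ `grad_fn`:

```python
import functools
import gc
import weakref

class Strategy:
    """Stand-in for the prediction strategy and its memoize cache."""
    def __init__(self):
        self._memoize_cache = {}

def clear_cache_hook(strategy, *grads):
    strategy._memoize_cache.clear()

strategy = Strategy()
# functools.partial keeps a STRONG reference to `strategy`, mirroring
# the original hook registration:
hook = functools.partial(clear_cache_hook, strategy)
# Storing the hook somewhere reachable from the strategy (here, in its own
# cache; in gpytorch, on the cached tensor's grad_fn) closes the loop:
strategy._memoize_cache["mean_cache_hook"] = hook

probe = weakref.ref(strategy)
del strategy, hook
gc.collect()
print(probe() is None)  # True: a pure-Python cycle is still collectable
```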

Fix

Added register_cache_clear_hook(tsr, module) to gpytorch/utils/memoize.py. This helper uses weakref.ref(module) when registering the backward hook on tsr.grad_fn, breaking the reference cycle. When no external strong references to the prediction strategy remain, it and its caches are garbage-collected normally, and the hook becomes a no-op.

Updated all three call sites to use the new helper:

  • DefaultPredictionStrategy._exact_predictive_covar_inv_quad_form_cache
  • DefaultPredictionStrategy._mean_cache
  • _variational_strategy._add_cache_hook
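As a rough sketch of what such a weakref-based helper might look like (the actual implementation lives in `gpytorch/utils/memoize.py`; here `grad_fn.register_hook` is represented by any object with a duck-typed `register_hook` method):

```python
import weakref

def register_cache_clear_hook(grad_fn, module):
    # Hold the module only weakly, so the hook closure stored on the
    # (C++) grad_fn cannot keep the prediction strategy alive.
    module_ref = weakref.ref(module)

    def _clear_cache(*grads):
        mod = module_ref()
        if mod is not None:             # strategy still alive:
            mod._memoize_cache.clear()  # drop the stale cached tensors
        # else: strategy already collected; the hook is a harmless no-op

    grad_fn.register_hook(_clear_cache)
```

With this shape, once the last external strong reference to the prediction strategy goes away, the strategy and its cached tensors can be freed even though the hook remains registered on the autograd graph.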

Validation (ran manually)

Repro script (1000 iterations of fantasy model creation + evaluation with detach_test_caches(False), reporting traced memory every 100 iterations):

import tracemalloc

import torch
from gpytorch import settings as gpt_settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

d = 10
mc_points = torch.rand(32, d, dtype=torch.double)


class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


gp = SimpleGP(
    train_inputs=torch.rand(256, d, dtype=torch.double),
    train_targets=torch.rand(256, dtype=torch.double),
).eval()
gp(torch.rand(5, d, dtype=torch.double))  # set the caches before fantasize.

X = torch.rand(128, 5, d, dtype=torch.double, requires_grad=False)
Y = torch.rand(128, 5, dtype=torch.double, requires_grad=False)

tracemalloc.start()

for i in range(1000):
    fantasy_model = gp.get_fantasy_model(inputs=X, targets=Y).eval()
    with gpt_settings.detach_test_caches(False):
        fantasy_model(mc_points)
    if (i + 1) % 100 == 0:
        current, peak = tracemalloc.get_traced_memory()
        print(f"Iter {i + 1}: current={current / 1024 / 1024:.1f}MB, peak={peak / 1024 / 1024:.1f}MB")

Before (on main): the process is OOM-killed on a machine with 48GB RAM:

(botorch) saitcakmak@saitcakmak-mac gpytorch % git branch
  fix-memory-leak-mean-cache
* main
(botorch) saitcakmak@saitcakmak-mac gpytorch % python test.py
Iter 100: current=3.6MB, peak=3.6MB
Iter 200: current=6.7MB, peak=6.8MB
zsh: killed     python test.py

After (on the fix branch):

(botorch) saitcakmak@saitcakmak-mac gpytorch % git branch
* fix-memory-leak-mean-cache
  main
(botorch) saitcakmak@saitcakmak-mac gpytorch % python test.py
Iter 100: current=1.1MB, peak=1.8MB
Iter 200: current=1.4MB, peak=2.0MB
Iter 300: current=1.6MB, peak=2.3MB
Iter 400: current=0.4MB, peak=2.6MB
Iter 500: current=0.9MB, peak=2.6MB
Iter 600: current=1.3MB, peak=2.6MB
Iter 700: current=1.5MB, peak=2.6MB
Iter 800: current=2.6MB, peak=2.6MB
Iter 900: current=2.8MB, peak=2.8MB
Iter 1000: current=3.0MB, peak=3.1MB


Gradient propagation through fantasy model predictions with
`detach_test_caches(False)` was also verified to still work correctly.
All 59 existing tests in test_exact_gp.py and
test_derivative_gp_fantasy.py pass.


Successfully merging this pull request may close these issues:

[Bug] Memory leak in fantasy models when the mean cache is not detached