Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions gpytorch/models/exact_prediction_strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3

import functools
import string
import warnings

Expand All @@ -25,7 +24,7 @@
from .. import settings
from ..distributions import MultitaskMultivariateNormal
from ..lazy import LazyEvaluatedKernelTensor
from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
from ..utils.memoize import add_to_cache, cached, pop_from_cache, register_cache_clear_hook


def prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood):
Expand Down Expand Up @@ -108,10 +107,7 @@ def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root
if settings.detach_test_caches.on():
res = res.detach()

if res.grad_fn is not None:
wrapper = functools.partial(clear_cache_hook, self)
functools.update_wrapper(wrapper, clear_cache_hook)
res.grad_fn.register_hook(wrapper)
register_cache_clear_hook(res, self)

return res

Expand Down Expand Up @@ -313,10 +309,7 @@ def _mean_cache(self, nan_policy: str) -> Tensor:
if settings.detach_test_caches.on():
mean_cache = mean_cache.detach()

if mean_cache.grad_fn is not None:
wrapper = functools.partial(clear_cache_hook, self)
functools.update_wrapper(wrapper, clear_cache_hook)
mean_cache.grad_fn.register_hook(wrapper)
register_cache_clear_hook(mean_cache, self)

return mean_cache

Expand Down
19 changes: 19 additions & 0 deletions gpytorch/utils/memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import functools
import pickle
import weakref

from .errors import CachingError

Expand Down Expand Up @@ -46,6 +47,24 @@ def clear_cache_hook(module, *args, **kwargs):
module._memoize_cache = {}


def register_cache_clear_hook(tsr, module):
    """Attach a backward hook to ``tsr.grad_fn`` that empties ``module``'s memoize cache.

    The hook holds only a :func:`weakref.ref` to ``module``: a strong reference
    captured by the hook closure would be pinned by the C++ autograd ``grad_fn``
    node, creating a reference cycle that Python's garbage collector cannot
    traverse (it does not see through C++ objects). If ``tsr`` has no
    ``grad_fn`` (e.g. it was detached or created outside autograd), this is a
    no-op.
    """
    grad_fn = tsr.grad_fn
    if grad_fn is None:
        return

    module_ref = weakref.ref(module)

    def _clear_cache(*_args, **_kwargs):
        # The module may already have been collected; only clear if still alive.
        target = module_ref()
        if target is not None:
            target._memoize_cache = {}

    grad_fn.register_hook(_clear_cache)


def _cached(method=None, name=None):
"""A decorator allowing for specifying the name of a cache, allowing it to be modified elsewhere.
This variant honors the calling args to the decorated function.
Expand Down
8 changes: 2 additions & 6 deletions gpytorch/variational/_variational_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import functools
from abc import ABC, abstractproperty
from copy import deepcopy

Expand All @@ -18,7 +17,7 @@
from ..models import ApproximateGP, ExactGP
from ..models.exact_prediction_strategies import DefaultPredictionStrategy
from ..module import Module
from ..utils.memoize import add_to_cache, cached, clear_cache_hook
from ..utils.memoize import add_to_cache, cached, clear_cache_hook, register_cache_clear_hook
from . import _VariationalDistribution


Expand All @@ -42,10 +41,7 @@ def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:


def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor:
    """Wire a cache-clearing backward hook from ``tsr`` onto ``pred_strat``.

    Delegates to :func:`~gpytorch.utils.memoize.register_cache_clear_hook`
    (a no-op when ``tsr`` has no ``grad_fn``) and hands ``tsr`` back unchanged
    so the call can be used inline in expressions.
    """
    register_cache_clear_hook(tsr, pred_strat)
    return tsr


Expand Down
Loading