
Commit 24ee799

esantorella authored and facebook-github-bot committed
Make behavior of Cholesky caching more clear (#1504)
Summary:

## Motivation

General context: Caching is confusing and can lead to subtle issues. I have been trying to understand it better in order to reduce memory usage and improve runtime, since I've been seeing cache misses and tensors persisting longer than necessary. This PR doesn't fix that, but it does make things a tiny bit more transparent.

Two things make the `_cache_root_decomposition` method harder to understand than necessary:
1. It sets `self._baseline_L` and returns `None` rather than just returning `baseline_L`, so someone who sees a call to `_cache_root_decomposition` will not immediately realize that `self._baseline_L` has been set.
2. It uses two different kinds of caching: it sets `self._baseline_L`, and it also invisibly uses LinearOperator's caching.

This PR makes things more transparent by
* Adding comments
* Returning `baseline_L` rather than setting it as a side effect

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1504

Test Plan: Unit tests

Reviewed By: Balandat

Differential Revision: D41308840

Pulled By: esantorella

fbshipit-source-id: 0a4c9e331923c3a6b8cc8212e5f605f1c65b9901
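As a minimal illustration of the pattern change the summary describes (the `ToyAcquisition` class below is hypothetical, not BoTorch's `CachedCholeskyMCAcquisitionFunction`), the old method caches via a side effect while the new one returns the root for the caller to store explicitly:

```python
# Minimal sketch of the pattern change; `ToyAcquisition` is hypothetical,
# not the actual BoTorch class touched by this commit.
import torch
from torch import Tensor


class ToyAcquisition(torch.nn.Module):
    def _cache_root_decomposition(self, covar: Tensor) -> None:
        # Old style: the cache is written as a side effect, so a caller
        # cannot tell from the call site that `self._baseline_L` was set.
        baseline_L = torch.linalg.cholesky(covar)
        self.register_buffer("_baseline_L", baseline_L)

    def _compute_root_decomposition(self, covar: Tensor) -> Tensor:
        # New style: return the root; the caller decides where to store it,
        # e.g. `self._baseline_L = self._compute_root_decomposition(covar)`.
        return torch.linalg.cholesky(covar)
```

Returning the value makes the assignment visible at the call site, which is the transparency gain the PR is after.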
1 parent 8eea8e1 commit 24ee799

File tree

5 files changed: +41, -10 lines


botorch/acquisition/cached_cholesky.py

Lines changed: 15 additions & 4 deletions
@@ -97,12 +97,24 @@ def _setup(
                 cache_root = False
         self._cache_root = cache_root
 
-    def _cache_root_decomposition(
+    def _compute_root_decomposition(
         self,
         posterior: Posterior,
-    ) -> None:
+    ) -> Tensor:
         r"""Cache Cholesky of the posterior covariance over f(X_baseline).
 
+        Because `LinearOperator.root_decomposition` is decorated with LinearOperator's
+        @cached decorator, this function is doing a lot implicitly:
+
+        1) Check if a root decomposition has already been cached to `lazy_covar`.
+            Note that it will not have been if `posterior.mvn` is a
+            `MultitaskMultivariateNormal`, since we construct `lazy_covar` in that
+            case.
+        2) If the root decomposition has not been found in the cache, compute it.
+        3) Write it to the cache of `lazy_covar`. Note that this will become inaccessible
+            if `posterior.mvn` is a `MultitaskMultivariateNormal`, since in that case
+            `lazy_covar`'s scope is only this function.
+
         Args:
             posterior: The posterior over f(X_baseline).
         """
@@ -112,8 +124,7 @@ def _cache_root_decomposition(
             lazy_covar = posterior.mvn.lazy_covariance_matrix
         with gpt_settings.fast_computations.covar_root_decomposition(False):
             lazy_covar_root = lazy_covar.root_decomposition()
-            baseline_L = lazy_covar_root.root.to_dense()
-        self.register_buffer("_baseline_L", baseline_L)
+            return lazy_covar_root.root.to_dense()
 
     def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
         r"""Get posterior samples at the `q_in` new points from the joint posterior.
botorch/acquisition/monte_carlo.py

Lines changed: 13 additions & 1 deletion
@@ -254,6 +254,9 @@ def __init__(
             X_pending=X_pending,
         )
         self._setup(model=model, sampler=self.sampler, cache_root=cache_root)
+        # We make a copy here because we will write an attribute `base_samples`
+        # to `self.base_sampler.base_samples`, and we don't want to mutate
+        # `self.sampler`.
         self.base_sampler = deepcopy(self.sampler)
         if prune_baseline:
             X_baseline = prune_inferior_points(
@@ -272,13 +275,22 @@ def __init__(
             posterior = self.model.posterior(
                 X_baseline, posterior_transform=self.posterior_transform
             )
+            # Note: The root decomposition is cached in two different places. It
+            # may be confusing to have two different caches, but this is not
+            # trivial to change since each is needed for a different reason:
+            # - LinearOperator caching to `posterior.mvn` allows for reuse within
+            #   this function, which may be helpful if the same root decomposition
+            #   is produced by the calls to `self.base_sampler` and
+            #   `self._cache_root_decomposition`.
+            # - self._baseline_L allows a root decomposition to be persisted outside
+            #   this method.
             baseline_samples = self.base_sampler(posterior)
             baseline_obj = self.objective(baseline_samples, X=X_baseline)
             self.register_buffer("baseline_samples", baseline_samples)
             self.register_buffer(
                 "baseline_obj_max_values", baseline_obj.max(dim=-1).values
             )
-            self._cache_root_decomposition(posterior=posterior)
+            self._baseline_L = self._compute_root_decomposition(posterior=posterior)
 
     def _set_sampler(
         self,
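The new comment about copying the sampler can be illustrated with a small sketch; `ToySampler` below is a hypothetical stand-in, not BoTorch's `MCSampler`:

```python
# Why the deepcopy above matters: setting `base_samples` on the copy does not
# touch the original sampler. `ToySampler` is a hypothetical stand-in.
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToySampler:
    base_samples: Optional[torch.Tensor] = None


sampler = ToySampler()
base_sampler = deepcopy(sampler)
base_sampler.base_samples = torch.randn(4)
assert sampler.base_samples is None  # the original sampler was not mutated
```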

botorch/acquisition/multi_objective/monte_carlo.py

Lines changed: 5 additions & 1 deletion
@@ -545,7 +545,11 @@ def _set_cell_bounds(self, num_new_points: int) -> None:
             samples = self.base_sampler(posterior)
             # cache posterior
             if self._cache_root:
-                self._cache_root_decomposition(posterior=posterior)
+                # Note that this implicitly uses LinearOperator's caching to check if
+                # the proper root decomposition has already been cached to
+                # `posterior.mvn.lazy_covariance_matrix`, which it may have been in
+                # the call to `self.base_sampler`, and computes it if not found
+                self._baseline_L = self._compute_root_decomposition(posterior=posterior)
             obj = self.objective(samples, X=self.X_baseline)
             if self.constraints is not None:
                 feas = torch.stack(

test/acquisition/test_cached_cholesky.py

Lines changed: 7 additions & 3 deletions
@@ -110,7 +110,9 @@ def test_cache_root_decomposition(self):
                 with mock.patch(
                     CHOLESKY_PATH, return_value=baseline_L
                 ) as mock_cholesky:
-                    acqf._cache_root_decomposition(posterior=posterior)
+                    baseline_L_acqf = acqf._compute_root_decomposition(
+                        posterior=posterior
+                    )
                 mock_extract_batch_covar.assert_called_once_with(posterior.mvn)
                 mock_cholesky.assert_called_once()
             # test mvn
@@ -121,10 +123,12 @@ def test_cache_root_decomposition(self):
                 with mock.patch(
                     CHOLESKY_PATH, return_value=baseline_L
                 ) as mock_cholesky:
-                    acqf._cache_root_decomposition(posterior=posterior)
+                    baseline_L_acqf = acqf._compute_root_decomposition(
+                        posterior=posterior
+                    )
                 mock_extract_batch_covar.assert_not_called()
                 mock_cholesky.assert_called_once()
-                self.assertTrue(torch.equal(acqf._baseline_L, baseline_L))
+                self.assertTrue(torch.equal(baseline_L_acqf, baseline_L))
 
     def test_get_f_X_samples(self):
         tkwargs = {"device": self.device}

test/acquisition/test_monte_carlo.py

Lines changed: 1 addition & 1 deletion
@@ -577,7 +577,7 @@ def test_cache_root(self):
         pt = ScalarizedPosteriorTransform(weights=torch.tensor([-1]))
         with mock.patch.object(
             qNoisyExpectedImprovement,
-            "_cache_root_decomposition",
+            "_compute_root_decomposition",
         ) as mock_cache_root:
             acqf = qNoisyExpectedImprovement(
                 model=model,
