
LargeBatchVariationalStrategy caches triangular solves for test time#2698

Merged
jacobrgardner merged 5 commits into main from svgp_caches
Jan 16, 2026

Conversation

@jacobrgardner
Member

@jacobrgardner jacobrgardner commented Jan 15, 2026

The new LargeBatchVariationalStrategy performs triangular solves both against the inducing_values and to compute the middle_term K_zz^{-1/2} (S - I) K_zz^{-1/2}. Since neither of these depends on the data, we can cache them in eval mode, circumventing the need for linear solves entirely at test time for both the mean and the variance.
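
As a rough sketch of this pattern (not the exact implementation, which works with linear_operator types and may take different arguments; the method signature and the _cached_middle_term attribute name below are guesses), the eval-mode branch of _compute_predictive_updates looks roughly like the following, with chol_kzz denoting the Cholesky factor L of K_zz:

import torch

def _compute_predictive_updates(self, chol_kzz, inducing_values, variational_cov):
    # In eval mode, reuse the cached solves -- no triangular solves at test time.
    if not self.training and getattr(self, "_cached_inv_chol_t_inducing_values", None) is not None:
        return self._cached_inv_chol_t_inducing_values, self._cached_middle_term

    # L^{-T} m: triangular solve against the inducing values (mean term).
    inv_chol_t_inducing_values = torch.linalg.solve_triangular(
        chol_kzz.mT, inducing_values.unsqueeze(-1), upper=True
    )
    # K_zz^{-1/2} (S - I) K_zz^{-1/2}: two triangular solves against L (variance term).
    eye = torch.eye(chol_kzz.size(-1), dtype=chol_kzz.dtype, device=chol_kzz.device)
    half = torch.linalg.solve_triangular(chol_kzz, variational_cov - eye, upper=False)
    middle_term = torch.linalg.solve_triangular(chol_kzz, half.mT, upper=False).mT

    if not self.training:
        # Neither quantity depends on the test inputs, so caching them is safe.
        self._cached_inv_chol_t_inducing_values = inv_chol_t_inducing_values
        self._cached_middle_term = middle_term
    return inv_chol_t_inducing_values, middle_term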

This also simplifies ONNX exporting relative to the original variational strategy (since ONNX does not directly support triangular_solve), so I added a unit test covering ONNX export of LargeBatchVariationalStrategy GP models.

Copilot AI left a comment

Pull request overview

This PR adds inference-time caching for triangular solves in LargeBatchVariationalStrategy to improve performance during test/evaluation mode. The caching mechanism stores intermediate computation results that don't depend on input data, eliminating redundant calculations during repeated inference calls.

Changes:

  • Implemented caching logic for inv_chol_t_inducing_values and middle_term during eval mode in _compute_predictive_updates
  • Added _clear_cache method to properly clean up inference caches when switching to training mode (see the sketch after this list)
  • Added comprehensive tests for caching behavior and ONNX export support with fp64 precision
  • Updated CI workflow to install ONNX dependencies for testing
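
A minimal sketch of the invalidation path from the second bullet (whether the PR wires _clear_cache into train() exactly this way is an assumption, and the class name here is only a stand-in):

import torch

class LargeBatchVariationalStrategySketch(torch.nn.Module):
    # Illustrative stand-in for the real strategy class.
    def _clear_cache(self):
        # Drop eval-mode caches so stale solves never leak into training.
        self._cached_inv_chol_t_inducing_values = None
        self._cached_middle_term = None

    def train(self, mode=True):
        if mode:
            self._clear_cache()
        return super().train(mode)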

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File / Description
gpytorch/variational/large_batch_variational_strategy.py: Implements the caching mechanism for inference-time optimization and cache cleanup
test/variational/test_large_batch_variational_strategy.py: Adds tests for caching behavior and ONNX export functionality
.github/workflows/run_test_suite.yml: Adds ONNX-related dependencies to the test environment


@kayween kayween self-assigned this Jan 15, 2026
Comment on lines +251 to +271
def softplus_symbolic(g, self, beta, threshold):
# Numerically stable Softplus using Where:
# softplus(x) = x + log(1 + exp(-x)) for x > 0
# = log(1 + exp(x)) for x <= 0
scaled = g.op("Mul", self, beta)
zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64))
one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64))
condition = g.op("Greater", scaled, zero)

neg_scaled = g.op("Neg", scaled)
exp_neg = g.op("Exp", neg_scaled)
one_plus_exp_neg = g.op("Add", one, exp_neg)
log_pos = g.op("Log", one_plus_exp_neg)
result_pos = g.op("Add", scaled, log_pos)

exp_scaled = g.op("Exp", scaled)
one_plus_exp = g.op("Add", one, exp_scaled)
result_neg = g.op("Log", one_plus_exp)

stable_result = g.op("Where", condition, result_pos, result_neg)
return g.op("Div", stable_result, beta)
Collaborator

I am not super familiar with ONNX. Does ONNX not have softplus so that we need to implement a custom one here? Similarly, is the transpose function necessary?

Member Author

It has softplus, but only an fp32 version (for whatever reason).

It supports transpose, but not .mT specifically, which is used all over the place in linear_operator; replacing it everywhere there would be a much bigger fix than handling it here.
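
For readers unfamiliar with the mechanism, here is a sketch of how a custom symbolic like the one above could be wired up at export time (the exact registration call, opset version, and model used in the PR's test are assumptions, and the module below is a toy stand-in):

import io
import torch

# Assumes the softplus_symbolic function quoted above is in scope.
torch.onnx.register_custom_op_symbolic("aten::softplus", softplus_symbolic, opset_version=17)

class TinySoftplusModule(torch.nn.Module):  # hypothetical stand-in, not the GP model from the test
    def forward(self, x):
        return torch.nn.functional.softplus(x)

model = TinySoftplusModule().double().eval()
x = torch.randn(5, dtype=torch.float64)
torch.onnx.export(model, (x,), io.BytesIO(), opset_version=17)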

Comment on lines +256 to +257
zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64))
one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64))
Collaborator

nit: maybe use torch.tensor(0.0, dtype=torch.float64) so that it does not change the shape if x is a scalar.

Member Author

Yeah, this is just more onnx idiosyncrasies unfortunately.
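
For context on the nit, this is the shape difference it refers to (a standalone illustration, not code from the PR):

import torch

x = torch.tensor(2.0, dtype=torch.float64)             # 0-d (scalar) input
vec_zero = torch.tensor([0.0], dtype=torch.float64)    # shape (1,), as in the diff
scalar_zero = torch.tensor(0.0, dtype=torch.float64)   # shape (), as suggested

print((x > vec_zero).shape)     # torch.Size([1]) -- broadcasting promotes the result to 1-d
print((x > scalar_zero).shape)  # torch.Size([])  -- the result stays a scalar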

Comment on lines +94 to +95
if not self.training:
self._cached_inv_chol_t_inducing_values = inv_chol_t_inducing_values
Collaborator

This would still trigger a triangular solve once at test-time. I am not sure if this is a concern?

If we use GPyTorch's cached decorator like this, we could avoid test-time triangular solves completely. It's probably a nit because we can do a forward pass before ONNX exporting.

Member Author

Yeah, we could make it decorated and then just add a grad hook. However, the current behavior is generally in line with how other caching in GPyTorch works (e.g., exact GPs cache on the first forward pass in eval mode).
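
For reference, the decorator-based alternative mentioned above might look something like the following (an illustrative sketch using gpytorch.utils.memoize.cached; the class and cache name are made up and do not appear in the PR):

import torch
from gpytorch.utils.memoize import cached

class CachedSolveSketch:
    # Illustrative only, not the strategy's real class.
    def __init__(self, chol_kzz, inducing_values):
        self.chol_kzz = chol_kzz
        self.inducing_values = inducing_values

    @property
    @cached(name="inv_chol_t_inducing_values_memo")
    def inv_chol_t_inducing_values(self):
        # Memoized on first access, so repeated eval-mode calls reuse the solve.
        return torch.linalg.solve_triangular(
            self.chol_kzz.mT, self.inducing_values.unsqueeze(-1), upper=True
        )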

@kayween
Collaborator

kayween commented Jan 15, 2026

Looks good to me. I left some inline comments. I am not super familiar with ONNX, but it looks sensible and the tests have passed.

@jacobrgardner Can you add a short description in the PR summary?

@jacobrgardner jacobrgardner merged commit 1726773 into main Jan 16, 2026
7 checks passed
@jacobrgardner jacobrgardner deleted the svgp_caches branch January 16, 2026 02:11