LargeBatchVariationalStrategy caches triangular solves for test time #2698
jacobrgardner merged 5 commits into main from
Conversation
Pull request overview
This PR adds inference-time caching for triangular solves in LargeBatchVariationalStrategy to improve performance during test/evaluation mode. The caching mechanism stores intermediate computation results that don't depend on input data, eliminating redundant calculations during repeated inference calls.
Changes:
- Implemented caching logic for `inv_chol_t_inducing_values` and `middle_term` during eval mode in `_compute_predictive_updates`
- Added a `_clear_cache` method to properly clean up inference caches when switching to training mode
- Added comprehensive tests for caching behavior and ONNX export support with fp64 precision
- Updated CI workflow to install ONNX dependencies for testing
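As a rough editorial sketch of the caching pattern those bullets describe (toy class and attribute names, not the PR's implementation): the triangular solve against the inducing values does not depend on the test inputs, so it can be computed once in eval mode, reused for every test batch, and dropped when training resumes.

```python
import torch

class CachedSolveToy(torch.nn.Module):
    """Toy stand-in for the eval-mode caching pattern (names are illustrative)."""

    def __init__(self, chol_inducing_covar, inducing_values):
        super().__init__()
        self.register_buffer("chol_inducing_covar", chol_inducing_covar)  # lower-triangular factor
        self.register_buffer("inducing_values", inducing_values)
        self._cached_solve = None

    def forward(self, x):
        if not self.training and self._cached_solve is not None:
            solve = self._cached_solve                      # reuse the cached triangular solve
        else:
            solve = torch.linalg.solve_triangular(
                self.chol_inducing_covar, self.inducing_values.unsqueeze(-1), upper=False
            )
            if not self.training:
                self._cached_solve = solve                  # data-independent, safe to cache
        return x @ solve

    def train(self, mode=True):
        if mode:
            self._cached_solve = None                       # analogue of the PR's _clear_cache
        return super().train(mode)
```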
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| gpytorch/variational/large_batch_variational_strategy.py | Implements caching mechanism for inference-time optimization and cache cleanup |
| test/variational/test_large_batch_variational_strategy.py | Adds tests for caching behavior and ONNX export functionality |
| .github/workflows/run_test_suite.yml | Adds ONNX-related dependencies to test environment |
```python
def softplus_symbolic(g, self, beta, threshold):
    # Numerically stable Softplus using Where:
    # softplus(x) = x + log(1 + exp(-x)) for x > 0
    #             = log(1 + exp(x))      for x <= 0
    scaled = g.op("Mul", self, beta)
    zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64))
    one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64))
    condition = g.op("Greater", scaled, zero)

    # Branch for scaled > 0: scaled + log(1 + exp(-scaled))
    neg_scaled = g.op("Neg", scaled)
    exp_neg = g.op("Exp", neg_scaled)
    one_plus_exp_neg = g.op("Add", one, exp_neg)
    log_pos = g.op("Log", one_plus_exp_neg)
    result_pos = g.op("Add", scaled, log_pos)

    # Branch for scaled <= 0: log(1 + exp(scaled))
    exp_scaled = g.op("Exp", scaled)
    one_plus_exp = g.op("Add", one, exp_scaled)
    result_neg = g.op("Log", one_plus_exp)

    # Select the numerically stable branch, then undo the beta scaling
    stable_result = g.op("Where", condition, result_pos, result_neg)
    return g.op("Div", stable_result, beta)
```
I am not super familiar with ONNX. Does ONNX not have softplus so that we need to implement a custom one here? Similarly, is the transpose function necessary?
It has softplus, but only an fp32 version (for whatever reason).
It supports transpose, but not `.mT` specifically, which is used all over the place in linear_operator and would be a much bigger fix to replace everywhere there than here.
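For context, a minimal sketch of how such a symbolic override can be registered for export, assuming `softplus_symbolic` is defined as above; the opset number here is illustrative, not taken from the PR:

```python
import torch
from torch.onnx import register_custom_op_symbolic

# Route aten::softplus through the Where-based fp64 construction above instead of
# the default lowering to the (fp32-only) ONNX Softplus op.
register_custom_op_symbolic("aten::softplus", softplus_symbolic, opset_version=17)
```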
```python
zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64))
one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64))
```
nit: maybe use `torch.tensor(0.0, dtype=torch.float64)` so that it does not change the shape if `x` is a scalar.
Yeah, this is just more onnx idiosyncrasies unfortunately.
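To illustrate the shape concern (an editorial example, not part of the review thread):

```python
import torch

x = torch.tensor(2.0, dtype=torch.float64)  # 0-d scalar input
print((x + torch.tensor([0.0], dtype=torch.float64)).shape)  # torch.Size([1]) -- shape changes
print((x + torch.tensor(0.0, dtype=torch.float64)).shape)    # torch.Size([])  -- shape preserved
```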
```python
if not self.training:
    self._cached_inv_chol_t_inducing_values = inv_chol_t_inducing_values
```
This would still trigger a triangular solve once at test time. I am not sure if this is a concern.
If we use GPyTorch's cached decorator like this, we could avoid test-time triangular solves completely. It's probably a nit, because we can do a forward pass before ONNX exporting.
Yeah, we could make it decorated and then just add a grad hook. However, the current behavior is generally in line with how other caching in GPyTorch works (e.g. exact GPs cache on the first forward pass in eval mode).
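For reference, a rough sketch of the decorator-based alternative the reviewer mentions, using `gpytorch.utils.memoize.cached`; the class and attribute names are made up for illustration, and this is not what the PR implements (the merged code caches during the first eval-mode forward pass instead):

```python
import torch
from gpytorch.utils.memoize import cached

class WhitenedSolveCache:
    """Illustrative only: memoize the data-independent triangular solve."""

    def __init__(self, chol_inducing_covar, inducing_values):
        self.chol_inducing_covar = chol_inducing_covar
        self.inducing_values = inducing_values

    @property
    @cached(name="inv_chol_t_inducing_values")
    def inv_chol_t_inducing_values(self):
        # Computed on first access; later accesses read the memoized value until
        # the cache is cleared (e.g. when switching back to training).
        return torch.linalg.solve_triangular(
            self.chol_inducing_covar, self.inducing_values.unsqueeze(-1), upper=False
        )
```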
Looks good to me. I left some inline comments. I am not super familiar with ONNX, but it looks sensible and the tests have passed. @jacobrgardner Can you add a short description in the PR summary?
The new LargeBatchVariationalStrategy computes triangular solves both against the `inducing_values` and to compute the `middle_term` K_zz^{-1/2} (S - I) K_zz^{-1/2}. Since neither of these depends on the data, we can cache them in eval mode, circumventing the need for linear solves entirely at test time for both the mean and the variance. This also simplifies ONNX exporting relative to the original variational strategy (since ONNX does not directly support `triangular_solve`), so I added a unit test to cover easy ONNX exporting of LargeBatchVariationalStrategy GP models.
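As a usage-level illustration (not the PR's actual unit test), exporting such a model might look roughly like the following; `gp_model` and `example_x` are assumed to already exist, and the opset number is arbitrary:

```python
import torch

class ONNXWrapper(torch.nn.Module):
    """Return plain tensors so torch.onnx.export can trace the GP model."""

    def __init__(self, gp_model):
        super().__init__()
        self.gp_model = gp_model

    def forward(self, x):
        dist = self.gp_model(x)
        return dist.mean, dist.variance

gp_model.double().eval()                 # fp64 parameters; eval mode enables the caches
with torch.no_grad():
    _ = gp_model(example_x)              # one forward pass populates the cached solves
    torch.onnx.export(ONNXWrapper(gp_model), (example_x,), "svgp.onnx", opset_version=17)
```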