A variational strategy optimized for large batch stochastic training #2688
kayween wants to merge 1 commit into cornellius-gp:main
Conversation
Pull request overview
This PR introduces LargeBatchVariationalStrategy, a performance-optimized variant of the standard variational strategy designed for large batch training on data center GPUs (particularly A100). The key optimization is regrouping matrix operations to compute the middle term K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} in double precision, reducing the number of O(m²n) operations from three to one. This yields a 3.81x speedup on large batch training according to benchmarks on the houseelectric dataset.
Key changes:
- New `LargeBatchVariationalStrategy` class that overrides the forward method to use a more efficient computation order
- Uses double precision for the middle term computation to maintain numerical stability despite the regrouping
- Adds basic test coverage comparing outputs and gradients with the standard `VariationalStrategy` in training mode
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| gpytorch/variational/large_batch_variational_strategy.py | New strategy implementation with optimized forward pass using regrouped operations and double precision for numerical stability |
| test/variational/test_large_batch_variational_strategy.py | Test file verifying equivalence with VariationalStrategy in training mode by comparing outputs and gradients |
| gpytorch/variational/__init__.py | Exports the new LargeBatchVariationalStrategy class |
```python
class TestLargeBatchVariationalStrategy(unittest.TestCase, BaseTestCase, CustomVariationalStrategyMixin):
    variational_strategy_class = LargeBatchVariationalStrategy
```
The test class only tests training mode behavior but lacks evaluation mode tests. The standard VariationalTestCase includes test_eval_iteration which verifies caching behavior and forward pass correctness in eval mode. This is important because eval mode may have different code paths and the caching mechanism needs to be verified. Consider either inheriting from VariationalTestCase (like test_variational_strategy.py does) or adding explicit eval mode tests.
```python
vec = torch.linalg.solve_triangular(
    L.mT.type(full_inputs.dtype), inducing_values.unsqueeze(-1), upper=True, left=True
)
predictive_mean = (induc_data_covar.mT @ vec).squeeze(-1) + test_mean
```
Inconsistent precision handling between mean and covariance computation. For the mean computation (lines 59-62), L.mT is cast to full_inputs.dtype (likely float32), but for the covariance computation (lines 70-72), L remains in double precision. This inconsistency makes the code harder to maintain and understand. According to the PR description, the double precision computation is intentional for numerical stability in the regrouped operations, but this should apply consistently. Consider either:
- Keeping L in double precision for both computations (more numerically stable)
- Adding comments explaining why different precisions are used
- If float32 is sufficient for the mean computation, documenting why
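The first option can be sketched in plain PyTorch. This is a hypothetical, self-contained illustration (all tensors below are made-up stand-ins for `L`, `inducing_values`, `induc_data_covar`, and `test_mean` from the quoted snippet), not the actual GPyTorch code:

```python
import torch

torch.manual_seed(0)
m, n = 8, 32
dtype = torch.float32  # the model's working dtype, i.e. full_inputs.dtype

# Hypothetical stand-ins for the tensors in the quoted snippet.
A = torch.randn(m, m, dtype=torch.float64)
L = torch.linalg.cholesky(A @ A.mT + m * torch.eye(m, dtype=torch.float64))  # double-precision factor
inducing_values = torch.randn(m, dtype=dtype)
induc_data_covar = torch.randn(m, n, dtype=dtype)
test_mean = torch.zeros(n, dtype=dtype)

# Reviewer option 1: keep L in double precision for the mean solve as well,
# and cast only the final result back to the working dtype.
vec = torch.linalg.solve_triangular(
    L.mT, inducing_values.double().unsqueeze(-1), upper=True, left=True
)
predictive_mean = (induc_data_covar.double().mT @ vec).squeeze(-1).to(dtype) + test_mean
```

This keeps the precision policy uniform across the mean and covariance paths, at the cost of one extra upcast/downcast in the mean computation.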
```python
    self,
    inducing_points,
    variational_strategy_class: type[VariationalStrategy] = VariationalStrategy,
    random_initialization: bool = False,
```
The random_initialization parameter is defined in the `__init__` signature but never used. This parameter should either be removed or implemented if it serves a purpose.
Suggested change: delete the unused line

```python
    random_initialization: bool = False,
```
```python
r"""A lightweight and performant variational strategy that is optimized for large batch training.

This class groups the middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}` in double precision.
```
The docstring lacks important details compared to the parent VariationalStrategy class. It should include:
- Parameter descriptions for all arguments (model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val)
- When to use this strategy vs the standard VariationalStrategy
- The trade-off of using double precision (performance benefit on A100 but potential memory overhead)
- References or citations for the optimization approach
Suggested change — replace the current docstring opening:

```python
r"""A lightweight and performant variational strategy that is optimized for large batch training.

This class groups the middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}` in double precision.
```

with:

```python
r"""A lightweight and performant variational strategy optimized for large-batch training.

This strategy has the same interface as :class:`gpytorch.variational.VariationalStrategy`,
but changes how the variational covariance term is computed for better numerical
stability and performance on modern accelerators.

In particular, it groups and evaluates the middle term

``K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}``

in a higher-precision linear algebra dtype (see :data:`gpytorch.settings._linalg_dtype_cholesky`)
before mapping results back to the model's dtype. This can improve the stability of the
Cholesky-based solves used in large-batch training, especially when many inducing points
are used or when kernels produce nearly singular covariance matrices.

Args:
    model: Approximate GP model that uses this variational strategy. This is typically
        a subclass of :class:`gpytorch.models.ApproximateGP`, and it provides the prior
        mean and covariance through its ``forward`` method.
    inducing_points: Tensor of inducing point locations :math:`Z` with shape
        ``(..., M, D)``, where ``M`` is the number of inducing points and ``D`` is the
        input dimensionality. These points define the variational approximation to the
        prior process.
    variational_distribution: Variational distribution over the inducing values
        :math:`q(u)`. This is typically an instance of
        :class:`gpytorch.variational.VariationalDistribution` (e.g.,
        :class:`gpytorch.variational.CholeskyVariationalDistribution`) that parameterizes
        the mean and covariance of the inducing outputs.
    learn_inducing_locations: If ``True``, the inducing point locations are treated as
        learnable parameters and optimized jointly with the model hyperparameters. If
        ``False``, the inducing locations are fixed.
    jitter_val: Small non-negative scalar added to the diagonal of covariance matrices
        for numerical stability during Cholesky factorizations. This is used when
        computing the inducing-inducing and data-data covariances.

When to use:
    Use :class:`LargeBatchVariationalStrategy` instead of the standard
    :class:`gpytorch.variational.VariationalStrategy` when training with very large batch
    sizes or many inducing points on hardware that benefits from mixed-precision or
    double-precision linear algebra (e.g., NVIDIA A100 GPUs). By performing the key
    ``K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}`` computation in a higher-precision dtype, this
    strategy can reduce numerical issues and improve throughput for large-batch
    variational training without changing the overall objective.

Double precision trade-off:
    Evaluating the grouped middle term in double precision (or a higher-precision
    Cholesky dtype) can yield substantial performance benefits and more stable
    optimization on accelerators that provide efficient double-precision units (such as
    A100 GPUs). However, this also increases memory usage, since intermediate matrices
    (e.g., the middle term and associated solves) are stored in the higher-precision
    dtype. On memory-constrained devices or very large models, this trade-off should be
    considered when selecting batch size and number of inducing points.

References:
    - M. Titsias. *Variational Learning of Inducing Variables in Sparse Gaussian Processes*.
      AISTATS, 2009.
    - J. Hensman, N. Fusi, N. D. Lawrence. *Gaussian Processes for Big Data*. UAI, 2013.
      (Variational inducing-point methods for large-scale GP models.)
```
```python
# TODO: Use a hook for this
try:
    pop_from_cache_ignore_args(self, "cholesky_factor")
except CachingError:
```
'except' clause does nothing but pass and there is no explanatory comment.
Suggested change:

```python
except CachingError:
    # If cache eviction fails, fall back to recomputing the Cholesky factor below.
```
I'll push a different branch for this implementation --- closing this PR for now.
This PR is a variational strategy implementation optimized for large batch stochastic training on data center GPUs (in particular, the A100). It yields about a 3.81x speedup for large batch training on the A100 compared to the current variational strategy implementation.
What's Changed?
Let `m` be the number of inducing points and `n` be the number of test data points (i.e., the batch size). The complexity of the forward pass is O(m^2 n), assuming `m << n`. The running-time bottleneck is computing the predictive variance, which requires evaluating the middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}`.
**GPyTorch.** The current GPyTorch implementation first computes the whitened interpolation term `K_ZZ^{-1/2} K_ZX` and then multiplies it by `S - I`. In total, there are three O(m^2 n) operations. Note that there are two matrix multiplications in the last step: `S - I` is stored as a `SumLinearOperator` and `S` is a `CholLinearOperator`, so multiplying by the Cholesky factor of `S` invokes two matrix multiplications.

**This PR.** We can save two O(m^2 n) matrix multiplications by grouping the operations in a better way: evaluate the m-by-m middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}` first, then apply it to the cross-covariance. In total, there is only a single O(m^2 n) matrix operation.
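The two orderings can be sketched in plain PyTorch. This is a hypothetical, dense-tensor illustration of the idea rather than the actual implementation in this PR; all matrix sizes and values are made up:

```python
import torch

torch.manual_seed(0)
m, n = 16, 512  # m inducing points, n batch points, with m << n

# Synthetic stand-ins (hypothetical; not the actual GPyTorch linear operators):
# K_ZZ is the inducing covariance, S the variational covariance, K_ZX the cross-covariance.
A = torch.randn(m, m)
K_ZZ = A @ A.mT + m * torch.eye(m)
B = torch.randn(m, m)
S = B @ B.mT + torch.eye(m)
K_ZX = torch.randn(m, n)

L = torch.linalg.cholesky(K_ZZ)  # K_ZZ = L L^T
C = S - torch.eye(m)             # the (S - I) term

# Current ordering: one O(m^2 n) triangular solve for the interpolation term,
# then applying S - I, which in GPyTorch costs two more O(m^2 n) matmuls because
# S is a CholLinearOperator (written here as one dense matmul for brevity).
interp = torch.linalg.solve_triangular(L, K_ZX, upper=False)  # L^{-1} K_ZX
var_naive = (interp * (C @ interp)).sum(dim=0)                # per-point variance correction

# Regrouped ordering: form the m-by-m middle term M = L^{-T} (S - I) L^{-1} once,
# in double precision for stability (O(m^3) work), leaving a single O(m^2 n)
# matmul against K_ZX.
Y = torch.linalg.solve_triangular(L.mT.double(), C.double(), upper=True)  # L^{-T} C
M = torch.linalg.solve_triangular(L.mT.double(), Y.mT, upper=True).mT     # L^{-T} C L^{-1}
var_regrouped = (K_ZX * (M.to(K_ZX.dtype) @ K_ZX)).sum(dim=0)
```

In the sketch, `M` is only m-by-m, so forming it costs O(m^3), which is cheaper than O(m^2 n) when `m << n`; only the final product `M @ K_ZX` touches an m-by-n matrix.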
The caveat is that the regrouping is not numerically friendly, so we have to do the computation in double precision. This is not an issue on modern data center GPUs, which have very good FP64 tensor cores: empirically, we observed that FP64 matmul is just as fast as FP32 matmul on the A100.
Benchmark
We ran SVGP training on four large-scale UCI datasets. The results are available at https://api.wandb.ai/links/kayween/lw2terwd
The test MAE and test NLL are virtually identical to the current GPyTorch implementation. This PR speeds up SVGP training by 3.81x on houseelectric.