
A variational strategy optimized for large batch stochastic training#2688

Closed
kayween wants to merge 1 commit into cornellius-gp:main from kayween:large-batch-variational

Conversation


@kayween (Collaborator) commented on Dec 18, 2025

This PR adds a variational strategy implementation optimized for large-batch stochastic training on data center GPUs (in particular, the A100). It yields about a 3.81x speed-up for large-batch training on an A100 compared to the current variational strategy implementation.

What's Changed?

Let m be the number of inducing points and n the number of test data points (i.e., the batch size). The complexity of the forward pass is O(m^2 n), assuming m << n.

The running-time bottleneck is computing the predictive variance, which requires evaluating

k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX

GPyTorch. The current implementation computes it in the following way:

interp_term = K_ZZ^{-1/2} k_ZX  # O(m^2 n) triangular solve
lazy_tensor = MatmulLinearOperator(
    interp_term.transpose(-2, -1),
    (S - I) @ interp_term,  # TWO O(m^2 n) matrix multiplications!
)

In total, there are three O(m^2 n) operations. Note that the last line performs two matrix multiplications: S - I is stored as a SumLinearOperator in which S is a CholLinearOperator, so multiplying by S goes through its Cholesky factor and costs two matmuls.

This PR. We can save two O(m^2 n) matrix multiplications by grouping the operations in a better way:

middle_term = K_ZZ^{-1/2} @ (S - I) @ K_ZZ^{-1/2}  # O(m^3), store middle_term as a tensor
lazy_tensor = MatmulLinearOperator(k_XZ, middle_term @ k_ZX)  # a single O(m^2 n) matmul

In total, there is only a single O(m^2 n) matrix operation.
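As a sanity check, the two groupings compute the same quadratic form. The sketch below is plain NumPy, not the GPyTorch code; all names are illustrative. It verifies on random SPD matrices that the regrouping leaves the result unchanged while moving all of the O(m^2 n) work into a single matmul:

```python
# Sketch (illustrative, not the GPyTorch implementation): verify that
# regrouping k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX leaves the result
# unchanged while moving all O(m^2 n) work into a single matmul.
import numpy as np

rng = np.random.default_rng(0)
m, n = 8, 100  # m inducing points, n batch points

A = rng.standard_normal((m, m))
K_ZZ = A @ A.T + m * np.eye(m)  # random SPD inducing covariance
B = rng.standard_normal((m, m))
S = B @ B.T + m * np.eye(m)     # random SPD variational covariance
k_ZX = rng.standard_normal((m, n))

L = np.linalg.cholesky(K_ZZ)    # K_ZZ = L L^T, so "K_ZZ^{-1/2} k_ZX" = L^{-1} k_ZX

# Current grouping: three O(m^2 n) operations.
interp_term = np.linalg.solve(L, k_ZX)                     # triangular solve
current = interp_term.T @ ((S - np.eye(m)) @ interp_term)  # two matmuls

# Regrouped: O(m^3) middle term, then a single O(m^2 n) matmul.
X = np.linalg.solve(L.T, S - np.eye(m))   # L^{-T} (S - I)
middle_term = np.linalg.solve(L.T, X.T)   # L^{-T} (S - I) L^{-1}, stored densely
regrouped = k_ZX.T @ (middle_term @ k_ZX)

assert np.allclose(current, regrouped)
```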

The caveat is that the regrouping is less numerically friendly, so we have to perform the middle-term computation in double precision. This is not an issue because modern data center GPUs have fast FP64 tensor cores; empirically, we observed that FP64 matmul is about as fast as FP32 matmul on an A100.
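A minimal sketch of this mixed-precision pattern (again plain NumPy with illustrative names, not the PR's code): the O(m^3) middle term is formed in float64 and cast back to float32 for the single batch-sized matmul, and the result stays close to a full float64 reference:

```python
# Sketch (illustrative): form the numerically sensitive middle term in FP64,
# then cast back to the working precision for the single O(m^2 n) matmul.
import numpy as np

rng = np.random.default_rng(1)
m, n = 8, 64
A = rng.standard_normal((m, m))
K_ZZ = (A @ A.T + m * np.eye(m)).astype(np.float32)
S = (lambda B: (B @ B.T + m * np.eye(m)).astype(np.float32))(rng.standard_normal((m, m)))
k_ZX = rng.standard_normal((m, n)).astype(np.float32)

# Promote to float64 for the regrouped middle-term computation.
L64 = np.linalg.cholesky(K_ZZ.astype(np.float64))
X = np.linalg.solve(L64.T, S.astype(np.float64) - np.eye(m))
middle64 = np.linalg.solve(L64.T, X.T)  # L^{-T} (S - I) L^{-1} in FP64

# Cast back: the batch-sized matmul runs in the working precision.
quad = k_ZX.T @ (middle64.astype(np.float32) @ k_ZX)

# Reference: the original grouping, entirely in float64.
interp = np.linalg.solve(L64, k_ZX.astype(np.float64))
ref = interp.T @ ((S.astype(np.float64) - np.eye(m)) @ interp)

rel_err = np.abs(quad - ref).max() / np.abs(ref).max()
assert rel_err < 1e-4
```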

Benchmark

We run SVGP training on four large-scale UCI datasets. The results are available at https://api.wandb.ai/links/kayween/lw2terwd

The test MAE and test NLL are virtually identical to the current GPyTorch implementation. This PR speeds up SVGP training by 3.81x on houseelectric.

Copilot AI review requested due to automatic review settings December 18, 2025 04:40

Copilot AI left a comment


Pull request overview

This PR introduces LargeBatchVariationalStrategy, a performance-optimized variant of the standard variational strategy designed for large batch training on data center GPUs (particularly A100). The key optimization is regrouping matrix operations to compute the middle term K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} in double precision, reducing the number of O(m²n) operations from three to one. This yields a 3.81x speedup on large batch training according to benchmarks on the houseelectric dataset.

Key changes:

  • New LargeBatchVariationalStrategy class that overrides the forward method to use a more efficient computation order
  • Uses double precision for the middle term computation to maintain numerical stability despite the regrouping
  • Adds basic test coverage comparing outputs and gradients with the standard VariationalStrategy in training mode

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File | Description
gpytorch/variational/large_batch_variational_strategy.py | New strategy implementation with optimized forward pass using regrouped operations and double precision for numerical stability
test/variational/test_large_batch_variational_strategy.py | Test file verifying equivalence with VariationalStrategy in training mode by comparing outputs and gradients
gpytorch/variational/__init__.py | Exports the new LargeBatchVariationalStrategy class


Comment on lines +95 to +96
class TestLargeBatchVariationalStrategy(unittest.TestCase, BaseTestCase, CustomVariationalStrategyMixin):
    variational_strategy_class = LargeBatchVariationalStrategy

Copilot AI Dec 18, 2025


The test class only tests training mode behavior but lacks evaluation mode tests. The standard VariationalTestCase includes test_eval_iteration which verifies caching behavior and forward pass correctness in eval mode. This is important because eval mode may have different code paths and the caching mechanism needs to be verified. Consider either inheriting from VariationalTestCase (like test_variational_strategy.py does) or adding explicit eval mode tests.

Comment on lines +59 to +62
vec = torch.linalg.solve_triangular(
    L.mT.type(full_inputs.dtype), inducing_values.unsqueeze(-1), upper=True, left=True
)
predictive_mean = (induc_data_covar.mT @ vec).squeeze(-1) + test_mean

Copilot AI Dec 18, 2025


Inconsistent precision handling between mean and covariance computation. For the mean computation (lines 59-62), L.mT is cast to full_inputs.dtype (likely float32), but for the covariance computation (lines 70-72), L remains in double precision. This inconsistency makes the code harder to maintain and understand. According to the PR description, the double precision computation is intentional for numerical stability in the regrouped operations, but this should apply consistently. Consider either:

  1. Keeping L in double precision for both computations (more numerically stable)
  2. Adding comments explaining why different precisions are used
  3. If float32 is sufficient for the mean computation, documenting why

    self,
    inducing_points,
    variational_strategy_class: type[VariationalStrategy] = VariationalStrategy,
    random_initialization: bool = False,

Copilot AI Dec 18, 2025


The random_initialization parameter is defined in the init signature but never used. This parameter should either be removed or implemented if it serves a purpose.

Suggested change
random_initialization: bool = False,

Comment on lines +19 to +21
r"""A lightweight and performant variational strategy that is optimized for large batch training.

This class groups the middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}` in double precision.

Copilot AI Dec 18, 2025


The docstring lacks important details compared to the parent VariationalStrategy class. It should include:

  1. Parameter descriptions for all arguments (model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val)
  2. When to use this strategy vs the standard VariationalStrategy
  3. The trade-off of using double precision (performance benefit on A100 but potential memory overhead)
  4. References or citations for the optimization approach
Suggested change
r"""A lightweight and performant variational strategy that is optimized for large batch training.
This class groups the middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}` in double precision.
r"""A lightweight and performant variational strategy optimized for large-batch training.
This strategy has the same interface as :class:`gpytorch.variational.VariationalStrategy`,
but changes how the variational covariance term is computed for better numerical
stability and performance on modern accelerators.
In particular, it groups and evaluates the middle term
``K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}``
in a higher-precision linear algebra dtype (see :data:`gpytorch.settings._linalg_dtype_cholesky`)
before mapping results back to the model's dtype. This can improve the stability of the
Cholesky-based solves used in large-batch training, especially when many inducing points
are used or when kernels produce nearly singular covariance matrices.
Args:
model: Approximate GP model that uses this variational strategy. This is typically
a subclass of :class:`gpytorch.models.ApproximateGP`, and it provides the prior
mean and covariance through its ``forward`` method.
inducing_points: Tensor of inducing point locations :math:`Z` with shape
``(..., M, D)``, where ``M`` is the number of inducing points and ``D`` is the
input dimensionality. These points define the variational approximation to the
prior process.
variational_distribution: Variational distribution over the inducing values
:math:`q(u)`. This is typically an instance of
:class:`gpytorch.variational.VariationalDistribution` (e.g.,
:class:`gpytorch.variational.CholeskyVariationalDistribution`) that parameterizes
the mean and covariance of the inducing outputs.
learn_inducing_locations: If ``True``, the inducing point locations are treated as
learnable parameters and optimized jointly with the model hyperparameters. If
``False``, the inducing locations are fixed.
jitter_val: Small non-negative scalar added to the diagonal of covariance matrices
for numerical stability during Cholesky factorizations. This is used when
computing the inducing-inducing and data-data covariances.
When to use:
Use :class:`LargeBatchVariationalStrategy` instead of the standard
:class:`gpytorch.variational.VariationalStrategy` when training with very large batch
sizes or many inducing points on hardware that benefits from mixed-precision or
double-precision linear algebra (e.g., NVIDIA A100 GPUs). By performing the key
``K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}`` computation in a higher-precision dtype, this
strategy can reduce numerical issues and improve throughput for large-batch
variational training without changing the overall objective.
Double precision trade-off:
Evaluating the grouped middle term in double precision (or a higher-precision
Cholesky dtype) can yield substantial performance benefits and more stable
optimization on accelerators that provide efficient double-precision units (such as
A100 GPUs). However, this also increases memory usage, since intermediate matrices
(e.g., the middle term and associated solves) are stored in the higher-precision
dtype. On memory-constrained devices or very large models, this trade-off should be
considered when selecting batch size and number of inducing points.
References:
- M. Titsias. *Variational Learning of Inducing Variables in Sparse Gaussian Processes*.
AISTATS, 2009.
- J. Hensman, N. Fusi, N. D. Lawrence. *Gaussian Processes for Big Data*. UAI, 2013.
(Variational inducing-point methods for large-scale GP models.)

# TODO: Use a hook for this
try:
    pop_from_cache_ignore_args(self, "cholesky_factor")
except CachingError:

Copilot AI Dec 18, 2025


'except' clause does nothing but pass and there is no explanatory comment.

Suggested change
except CachingError:
except CachingError:
# If cache eviction fails, fall back to recomputing the Cholesky factor below.


@kayween (Collaborator, Author) commented on Dec 19, 2025

I'll push a different branch for this implementation --- closing this PR for now.

@kayween closed this on Dec 19, 2025