Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/run_test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
pip install "pyro-ppl>=1.8";
pip install pykeops;
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
pip install onnx onnxruntime onnxscript; # For ONNX export tests
fi
- name: Run unit tests
run: |
Expand Down
41 changes: 31 additions & 10 deletions gpytorch/variational/large_batch_variational_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ class LargeBatchVariationalStrategy(VariationalStrategy):
CPUs and consumer cards should use `VariationalStrategy` instead.
"""

def _clear_cache(self) -> None:
    """Drop the cached inference-time tensors, then delegate to the parent reset.

    The two attributes below are lazily created during eval-mode forward
    passes; they must be removed whenever the variational parameters can
    change (e.g. on switching back to training).
    """
    for cache_attr in ("_cached_inv_chol_t_inducing_values", "_cached_middle_term"):
        if hasattr(self, cache_attr):
            delattr(self, cache_attr)
    super()._clear_cache()

def _compute_predictive_updates(
self,
chol: LinearOperator,
Expand All @@ -75,19 +83,32 @@ def _compute_predictive_updates(
inducing_values = inducing_values.type(torch.float64)

# The mean update `k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z)`
inv_chol_t_inducing_values = torch.linalg.solve_triangular(
chol.mT, inducing_values.unsqueeze(-1), upper=True, left=True
)
# Cache inv_chol_t_inducing_values and middle_term during inference (not training)
# since they don't depend on the data (through induc_data_covar)
if not self.training and hasattr(self, "_cached_inv_chol_t_inducing_values"):
inv_chol_t_inducing_values = self._cached_inv_chol_t_inducing_values
else:
inv_chol_t_inducing_values = torch.linalg.solve_triangular(
chol.mT, inducing_values.unsqueeze(-1), upper=True, left=True
)
if not self.training:
self._cached_inv_chol_t_inducing_values = inv_chol_t_inducing_values
Comment on lines +94 to +95
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would still trigger a triangular solve once at test-time. I am not sure if this is a concern?

If we use GPyTorch's cached decorator like this, we could avoid test-time triangular solves completely. It's probably a nit because we can do a forward pass before ONNX exporting.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we could make it decorated and then just add a grad hook. However, the current behavior is generally in line with how other caching in GPyTorch works (e.g. exact GPs cache the first forward through in eval mode).


mean_update = (induc_data_covar.mT @ inv_chol_t_inducing_values).squeeze(-1).type(dtype)

# The grouped middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}`
middle_term = prior_covar.mul(-1).to_dense()
if variational_inducing_covar is not None:
middle_term = variational_inducing_covar.to_dense() + middle_term
middle_term = middle_term.type(torch.float64)

middle_term = torch.linalg.solve_triangular(chol, middle_term, upper=False, left=False)
middle_term = torch.linalg.solve_triangular(chol.mT, middle_term, upper=True, left=True)
if not self.training and hasattr(self, "_cached_middle_term"):
middle_term = self._cached_middle_term
else:
middle_term = prior_covar.mul(-1).to_dense()
if variational_inducing_covar is not None:
middle_term = variational_inducing_covar.to_dense() + middle_term
middle_term = middle_term.type(torch.float64)

middle_term = torch.linalg.solve_triangular(chol, middle_term, upper=False, left=False)
middle_term = torch.linalg.solve_triangular(chol.mT, middle_term, upper=True, left=True)
if not self.training:
self._cached_middle_term = middle_term

# The covariance update `K_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} K_ZX`
if diag and self.training:
Expand Down
204 changes: 204 additions & 0 deletions test/variational/test_large_batch_variational_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,207 @@ def test_against_variational_strategy(self, train: bool = True):

def test_against_variational_strategy_eval(self):
    # Re-run the comparison against the reference VariationalStrategy,
    # but with the models in eval (inference) mode instead of train mode.
    self.test_against_variational_strategy(train=False)

def test_inference_caching(self):
    """Caches must appear only in eval mode, be reused, and clear on train()."""
    torch.manual_seed(42)

    model, likelihood = self._make_model_and_likelihood(
        batch_shape=self.batch_shape,
        strategy_cls=self.strategy_cls,
        distribution_cls=self.distribution_cls,
    )
    cache_names = ("_cached_inv_chol_t_inducing_values", "_cached_middle_term")

    # Training mode: a forward pass must not populate any cache.
    model.train()
    likelihood.train()
    train_x = torch.rand(32, 2)
    _ = model(train_x)
    for name in cache_names:
        self.assertFalse(hasattr(model.variational_strategy, name))

    # Eval mode: the first forward pass should create both caches.
    model.eval()
    likelihood.eval()
    test_x = torch.rand(5, 2)
    with torch.no_grad():
        _ = model(test_x)
    for name in cache_names:
        self.assertTrue(hasattr(model.variational_strategy, name))

    # A second eval forward pass must reuse the exact same cached tensors
    # (identity check - they should not be recomputed).
    first_pass_caches = tuple(getattr(model.variational_strategy, name) for name in cache_names)
    with torch.no_grad():
        _ = model(test_x)
    for name, cached in zip(cache_names, first_pass_caches):
        self.assertTrue(getattr(model.variational_strategy, name) is cached)

    # Switching back to train mode invalidates the caches.
    model.train()
    _ = model(test_x)
    for name in cache_names:
        self.assertFalse(hasattr(model.variational_strategy, name))

def test_onnx_export(self):
    """Test that a trained SVGP with LargeBatchVariationalStrategy can be exported to ONNX in fp64.

    The test trains a small model, populates the eval-mode inference caches,
    exports a mean-prediction wrapper via the legacy TorchScript exporter,
    and verifies the ONNX graph against onnxruntime output.
    """
    import os
    import tempfile

    try:
        import onnx
        import onnxruntime as ort
    except ImportError:
        self.skipTest("onnx and onnxruntime required for this test")

    from torch.onnx import register_custom_op_symbolic

    # Create and train model in fp64
    torch.manual_seed(42)
    model, likelihood = self._make_model_and_likelihood(
        batch_shape=self.batch_shape,
        strategy_cls=self.strategy_cls,
        distribution_cls=self.distribution_cls,
    )
    model = model.double()
    likelihood = likelihood.double()

    # Quick training on toy data (fp64)
    train_x = torch.rand(32, 2, dtype=torch.float64)
    train_y = torch.sin(train_x[:, 0]) + 0.1 * torch.randn(32, dtype=torch.float64)

    model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    from gpytorch.mlls import VariationalELBO

    mll = VariationalELBO(likelihood, model, num_data=train_x.size(0))

    for _ in range(5):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()

    # Switch to eval mode
    model.eval()
    likelihood.eval()

    # Create a wrapper that returns the mean as a tensor (required for ONNX)
    class MeanPredictionWrapper(torch.nn.Module):
        def __init__(self, gp_model):
            super().__init__()
            self.gp_model = gp_model

        def forward(self, x):
            output = self.gp_model(x)
            return output.mean

    wrapper = MeanPredictionWrapper(model)
    wrapper.eval()

    # Do a dummy forward to populate the inference caches
    test_x = torch.rand(5, 2, dtype=torch.float64)
    with torch.no_grad():
        _ = wrapper(test_x)

    # Verify caches are populated and are in fp64
    self.assertTrue(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values"))
    self.assertTrue(hasattr(model.variational_strategy, "_cached_middle_term"))
    self.assertEqual(model.variational_strategy._cached_inv_chol_t_inducing_values.dtype, torch.float64)
    self.assertEqual(model.variational_strategy._cached_middle_term.dtype, torch.float64)

    # Get reference output for comparison
    with torch.no_grad():
        mean_pred = wrapper(test_x)
    self.assertEqual(mean_pred.shape, torch.Size([5]))
    self.assertEqual(mean_pred.dtype, torch.float64)

    # Register custom ONNX symbolics for ops not in default registry:
    # - aten::mT has no default symbolic (only plain transpose does)
    # - aten::softplus only ships an fp32 ONNX Softplus, so we build an
    #   fp64-safe, numerically stable version from primitive ops
    def mT_symbolic(g, self):
        # mT swaps the last two dimensions
        tensor_type = self.type()
        if tensor_type is not None and tensor_type.dim() is not None:
            rank = tensor_type.dim()
            perm = list(range(rank - 2)) + [rank - 1, rank - 2]
            return g.op("Transpose", self, perm_i=perm)
        # Rank unknown at export time: fall back to the 2-D transpose.
        return g.op("Transpose", self, perm_i=[1, 0])

    def softplus_symbolic(g, self, beta, threshold):
        # Numerically stable Softplus using Where:
        # softplus(x) = x + log(1 + exp(-x)) for x > 0
        #             = log(1 + exp(x))     for x <= 0
        scaled = g.op("Mul", self, beta)
        zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64))
        one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64))
        condition = g.op("Greater", scaled, zero)

        neg_scaled = g.op("Neg", scaled)
        exp_neg = g.op("Exp", neg_scaled)
        one_plus_exp_neg = g.op("Add", one, exp_neg)
        log_pos = g.op("Log", one_plus_exp_neg)
        result_pos = g.op("Add", scaled, log_pos)

        exp_scaled = g.op("Exp", scaled)
        one_plus_exp = g.op("Add", one, exp_scaled)
        result_neg = g.op("Log", one_plus_exp)

        stable_result = g.op("Where", condition, result_pos, result_neg)
        return g.op("Div", stable_result, beta)

    # Registration may fail if a symbolic is already registered; that's fine.
    try:
        register_custom_op_symbolic("aten::mT", mT_symbolic, 17)
    except Exception:
        pass
    try:
        register_custom_op_symbolic("aten::softplus", softplus_symbolic, 17)
    except Exception:
        pass

    with tempfile.TemporaryDirectory() as tmpdir:
        onnx_path = os.path.join(tmpdir, "model.onnx")

        # Build export kwargs - use legacy exporter if available (PyTorch 2.9+)
        export_kwargs = dict(
            input_names=["input"],
            output_names=["mean"],
            opset_version=17,
        )
        # dynamo=False forces the legacy TorchScript-based exporter.
        # Compare parsed (major, minor) tuples rather than version strings:
        # lexicographic string comparison is wrong (e.g. "2.10" < "2.9").
        torch_version = tuple(int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
        if hasattr(torch.onnx.export, "__wrapped__") or torch_version >= (2, 9):
            export_kwargs["dynamo"] = False

        torch.onnx.export(
            wrapper,
            test_x,
            onnx_path,
            **export_kwargs,
        )

        self.assertTrue(os.path.exists(onnx_path))
        self.assertGreater(os.path.getsize(onnx_path), 0)

        # Verify the ONNX model input is fp64
        model_onnx = onnx.load(onnx_path)
        input_type = model_onnx.graph.input[0].type.tensor_type.elem_type
        self.assertEqual(input_type, 11)  # ONNX TensorProto.DOUBLE = 11

        # Verify with onnxruntime (optimizations disabled so the exported
        # graph itself is what gets executed)
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        session = ort.InferenceSession(onnx_path, sess_options)
        onnx_output = session.run(None, {"input": test_x.numpy()})[0]

        # Compare ONNX output with PyTorch output
        self.assertAllClose(torch.from_numpy(onnx_output), mean_pred, rtol=1e-5, atol=1e-5)
Loading