Skip to content

Commit 1726773

Browse files
Merge pull request #2698 from cornellius-gp/svgp_caches
LargeBatchVariationalStrategy caches triangular solves for test time
2 parents ff32436 + 5a908df commit 1726773

File tree

3 files changed

+236
-10
lines changed

.github/workflows/run_test_suite.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ jobs:
5858
pip install "pyro-ppl>=1.8";
5959
pip install pykeops;
6060
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
61+
pip install onnx onnxruntime onnxscript; # For ONNX export tests
6162
fi
6263
- name: Run unit tests
6364
run: |

gpytorch/variational/large_batch_variational_strategy.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ class LargeBatchVariationalStrategy(VariationalStrategy):
5757
CPUs and consumer cards should use `VariationalStrategy` instead.
5858
"""
5959

60+
def _clear_cache(self) -> None:
    """Drop the test-time inference caches, then defer to the parent class."""
    # Remove the cached solve results (if they exist) so that no stale
    # values survive a switch back to training mode.
    for attr in ("_cached_inv_chol_t_inducing_values", "_cached_middle_term"):
        if hasattr(self, attr):
            delattr(self, attr)
    super()._clear_cache()
67+
6068
def _compute_predictive_updates(
6169
self,
6270
chol: LinearOperator,
@@ -75,19 +83,32 @@ def _compute_predictive_updates(
7583
inducing_values = inducing_values.type(torch.float64)
7684

7785
# The mean update `k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z)`
78-
inv_chol_t_inducing_values = torch.linalg.solve_triangular(
79-
chol.mT, inducing_values.unsqueeze(-1), upper=True, left=True
80-
)
86+
# Cache inv_chol_t_inducing_values and middle_term during inference (not training)
87+
# since they don't depend on the data (through induc_data_covar)
88+
if not self.training and hasattr(self, "_cached_inv_chol_t_inducing_values"):
89+
inv_chol_t_inducing_values = self._cached_inv_chol_t_inducing_values
90+
else:
91+
inv_chol_t_inducing_values = torch.linalg.solve_triangular(
92+
chol.mT, inducing_values.unsqueeze(-1), upper=True, left=True
93+
)
94+
if not self.training:
95+
self._cached_inv_chol_t_inducing_values = inv_chol_t_inducing_values
96+
8197
mean_update = (induc_data_covar.mT @ inv_chol_t_inducing_values).squeeze(-1).type(dtype)
8298

8399
# The grouped middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}`
84-
middle_term = prior_covar.mul(-1).to_dense()
85-
if variational_inducing_covar is not None:
86-
middle_term = variational_inducing_covar.to_dense() + middle_term
87-
middle_term = middle_term.type(torch.float64)
88-
89-
middle_term = torch.linalg.solve_triangular(chol, middle_term, upper=False, left=False)
90-
middle_term = torch.linalg.solve_triangular(chol.mT, middle_term, upper=True, left=True)
100+
if not self.training and hasattr(self, "_cached_middle_term"):
101+
middle_term = self._cached_middle_term
102+
else:
103+
middle_term = prior_covar.mul(-1).to_dense()
104+
if variational_inducing_covar is not None:
105+
middle_term = variational_inducing_covar.to_dense() + middle_term
106+
middle_term = middle_term.type(torch.float64)
107+
108+
middle_term = torch.linalg.solve_triangular(chol, middle_term, upper=False, left=False)
109+
middle_term = torch.linalg.solve_triangular(chol.mT, middle_term, upper=True, left=True)
110+
if not self.training:
111+
self._cached_middle_term = middle_term
91112

92113
# The covariance update `K_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} K_ZX`
93114
if diag and self.training:

test/variational/test_large_batch_variational_strategy.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55

6+
from gpytorch.mlls import VariationalELBO
67
from gpytorch.test.base_test_case import BaseTestCase
78
from gpytorch.variational.large_batch_variational_strategy import LargeBatchVariationalStrategy, QuadFormDiagonal
89
from gpytorch.variational.variational_strategy import VariationalStrategy
@@ -111,3 +112,206 @@ def test_against_variational_strategy(self, train: bool = True):
111112

112113
def test_against_variational_strategy_eval(self):
113114
self.test_against_variational_strategy(train=False)
115+
116+
def test_inference_caching(self):
    """Verify the test-time caches: absent in train mode, created and reused
    in eval mode, and cleared again after switching back to training."""
    torch.manual_seed(42)

    model, likelihood = self._make_model_and_likelihood(
        batch_shape=self.batch_shape,
        strategy_cls=self.strategy_cls,
        distribution_cls=self.distribution_cls,
    )
    strategy = model.variational_strategy

    # A training-mode forward pass must not populate any cache.
    model.train()
    likelihood.train()
    train_x = torch.rand(32, 2)
    _ = model(train_x)
    self.assertFalse(hasattr(strategy, "_cached_inv_chol_t_inducing_values"))
    self.assertFalse(hasattr(strategy, "_cached_middle_term"))

    # The first eval-mode forward pass creates both caches.
    model.eval()
    likelihood.eval()
    test_x = torch.rand(5, 2)
    with torch.no_grad():
        _ = model(test_x)
    self.assertTrue(hasattr(strategy, "_cached_inv_chol_t_inducing_values"))
    self.assertTrue(hasattr(strategy, "_cached_middle_term"))

    # A second eval-mode pass must hand back the very same cached objects,
    # i.e. nothing is recomputed.
    cached_inv_chol = strategy._cached_inv_chol_t_inducing_values
    cached_middle = strategy._cached_middle_term
    with torch.no_grad():
        _ = model(test_x)
    self.assertTrue(strategy._cached_inv_chol_t_inducing_values is cached_inv_chol)
    self.assertTrue(strategy._cached_middle_term is cached_middle)

    # Returning to train mode clears both caches on the next forward pass.
    model.train()
    _ = model(test_x)
    self.assertFalse(hasattr(strategy, "_cached_inv_chol_t_inducing_values"))
    self.assertFalse(hasattr(strategy, "_cached_middle_term"))
166+
167+
def test_onnx_export(self):
    """Test that a trained SVGP with LargeBatchVariationalStrategy can be exported to ONNX in fp64."""
    import os
    import re
    import tempfile

    try:
        import onnx
        import onnxruntime as ort
    except ImportError:
        self.skipTest("onnx and onnxruntime required for this test")

    from torch.onnx import register_custom_op_symbolic

    # Create and train model in fp64
    torch.manual_seed(42)
    model, likelihood = self._make_model_and_likelihood(
        batch_shape=self.batch_shape,
        strategy_cls=self.strategy_cls,
        distribution_cls=self.distribution_cls,
    )
    model = model.double()
    likelihood = likelihood.double()

    # Quick training on toy data (fp64)
    train_x = torch.rand(32, 2, dtype=torch.float64)
    train_y = torch.sin(train_x[:, 0]) + 0.1 * torch.randn(32, dtype=torch.float64)

    model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    mll = VariationalELBO(likelihood, model, num_data=train_x.size(0))

    for _ in range(5):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()

    # Switch to eval mode
    model.eval()
    likelihood.eval()

    # Create a wrapper that returns the mean as a tensor (required for ONNX)
    class MeanPredictionWrapper(torch.nn.Module):
        def __init__(self, gp_model):
            super().__init__()
            self.gp_model = gp_model

        def forward(self, x):
            output = self.gp_model(x)
            return output.mean

    wrapper = MeanPredictionWrapper(model)
    wrapper.eval()

    # Do a dummy forward to populate the inference caches
    test_x = torch.rand(5, 2, dtype=torch.float64)
    with torch.no_grad():
        _ = wrapper(test_x)

    # Verify caches are populated and are in fp64
    self.assertTrue(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values"))
    self.assertTrue(hasattr(model.variational_strategy, "_cached_middle_term"))
    self.assertEqual(model.variational_strategy._cached_inv_chol_t_inducing_values.dtype, torch.float64)
    self.assertEqual(model.variational_strategy._cached_middle_term.dtype, torch.float64)

    # Get reference output for comparison
    with torch.no_grad():
        mean_pred = wrapper(test_x)
    self.assertEqual(mean_pred.shape, torch.Size([5]))
    self.assertEqual(mean_pred.dtype, torch.float64)

    # Register custom ONNX symbolics for ops not in default registry
    def mT_symbolic(g, self):
        # mT swaps the last two dimensions
        tensor_type = self.type()
        if tensor_type is not None and tensor_type.dim() is not None:
            rank = tensor_type.dim()
            perm = list(range(rank - 2)) + [rank - 1, rank - 2]
            return g.op("Transpose", self, perm_i=perm)
        return g.op("Transpose", self, perm_i=[1, 0])

    def softplus_symbolic(g, self, beta, threshold):
        # Numerically stable Softplus using Where:
        # softplus(x) = x + log(1 + exp(-x)) for x > 0
        #             = log(1 + exp(x))      for x <= 0
        scaled = g.op("Mul", self, beta)
        zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64))
        one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64))
        condition = g.op("Greater", scaled, zero)

        neg_scaled = g.op("Neg", scaled)
        exp_neg = g.op("Exp", neg_scaled)
        one_plus_exp_neg = g.op("Add", one, exp_neg)
        log_pos = g.op("Log", one_plus_exp_neg)
        result_pos = g.op("Add", scaled, log_pos)

        exp_scaled = g.op("Exp", scaled)
        one_plus_exp = g.op("Add", one, exp_scaled)
        result_neg = g.op("Log", one_plus_exp)

        stable_result = g.op("Where", condition, result_pos, result_neg)
        return g.op("Div", stable_result, beta)

    try:
        register_custom_op_symbolic("aten::mT", mT_symbolic, 17)
    except Exception:
        pass
    try:
        register_custom_op_symbolic("aten::softplus", softplus_symbolic, 17)
    except Exception:
        pass

    with tempfile.TemporaryDirectory() as tmpdir:
        onnx_path = os.path.join(tmpdir, "model.onnx")

        # Build export kwargs - use legacy exporter if available (PyTorch 2.9+)
        export_kwargs = dict(
            input_names=["input"],
            output_names=["mean"],
            opset_version=17,
        )
        # dynamo=False forces the legacy TorchScript-based exporter.
        # NOTE: compare (major, minor) numerically — a lexicographic string
        # comparison such as `torch.__version__ >= "2.9"` is wrong because
        # "2.10..." < "2.9" as strings.
        version_match = re.match(r"(\d+)\.(\d+)", torch.__version__)
        torch_version = (
            (int(version_match.group(1)), int(version_match.group(2))) if version_match else (0, 0)
        )
        if hasattr(torch.onnx.export, "__wrapped__") or torch_version >= (2, 9):
            export_kwargs["dynamo"] = False

        torch.onnx.export(
            wrapper,
            test_x,
            onnx_path,
            **export_kwargs,
        )

        self.assertTrue(os.path.exists(onnx_path))
        self.assertGreater(os.path.getsize(onnx_path), 0)

        # Verify the ONNX model input is fp64
        model_onnx = onnx.load(onnx_path)
        input_type = model_onnx.graph.input[0].type.tensor_type.elem_type
        self.assertEqual(input_type, 11)  # ONNX TensorProto.DOUBLE = 11

        # Verify with onnxruntime
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        session = ort.InferenceSession(onnx_path, sess_options)
        onnx_output = session.run(None, {"input": test_x.numpy()})[0]

        # Compare ONNX output with PyTorch output
        self.assertAllClose(torch.from_numpy(onnx_output), mean_pred, rtol=1e-5, atol=1e-5)

0 commit comments

Comments (0)