-
Notifications
You must be signed in to change notification settings - Fork 587
LargeBatchVariationalStrategy caches triangular solves for test time #2698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
e216666
fb220b7
1719345
2f2c3fe
5a908df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,6 +57,14 @@ class LargeBatchVariationalStrategy(VariationalStrategy): | |
| CPUs and consumer cards should use `VariationalStrategy` instead. | ||
| """ | ||
|
|
||
| def _clear_cache(self) -> None: | ||
| # Clear cached inference terms before calling parent's _clear_cache | ||
| if hasattr(self, "_cached_inv_chol_t_inducing_values"): | ||
| del self._cached_inv_chol_t_inducing_values | ||
| if hasattr(self, "_cached_middle_term"): | ||
| del self._cached_middle_term | ||
| super()._clear_cache() | ||
|
|
||
| def _compute_predictive_updates( | ||
| self, | ||
| chol: LinearOperator, | ||
|
|
@@ -75,19 +83,32 @@ def _compute_predictive_updates( | |
| inducing_values = inducing_values.type(torch.float64) | ||
|
|
||
| # The mean update `k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z)` | ||
| inv_chol_t_inducing_values = torch.linalg.solve_triangular( | ||
| chol.mT, inducing_values.unsqueeze(-1), upper=True, left=True | ||
| ) | ||
| # Cache inv_chol_t_inducing_values and middle_term during inference (not training) | ||
| # since they don't depend on the data (through induc_data_covar) | ||
| if not self.training and hasattr(self, "_cached_inv_chol_t_inducing_values"): | ||
| inv_chol_t_inducing_values = self._cached_inv_chol_t_inducing_values | ||
| else: | ||
| inv_chol_t_inducing_values = torch.linalg.solve_triangular( | ||
| chol.mT, inducing_values.unsqueeze(-1), upper=True, left=True | ||
| ) | ||
| if not self.training: | ||
| self._cached_inv_chol_t_inducing_values = inv_chol_t_inducing_values | ||
|
Comment on lines
+94
to
+95
Collaborator
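There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This would still trigger a triangular solve once at test time. I am not sure whether this is a concern. If we use GPyTorch's [inline code lost in extraction; presumably the `cached` decorator, since the reply discusses making the method "decorated"] … [remainder of the sentence lost in extraction]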
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, we could make it decorated and then just add a grad hook. However, the current behavior is generally in line with how other caching in GPyTorch works (e.g. exact GPs cache on the first forward pass in eval mode). |
||
|
|
||
| mean_update = (induc_data_covar.mT @ inv_chol_t_inducing_values).squeeze(-1).type(dtype) | ||
|
|
||
| # The grouped middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}` | ||
| middle_term = prior_covar.mul(-1).to_dense() | ||
| if variational_inducing_covar is not None: | ||
| middle_term = variational_inducing_covar.to_dense() + middle_term | ||
| middle_term = middle_term.type(torch.float64) | ||
|
|
||
| middle_term = torch.linalg.solve_triangular(chol, middle_term, upper=False, left=False) | ||
| middle_term = torch.linalg.solve_triangular(chol.mT, middle_term, upper=True, left=True) | ||
| if not self.training and hasattr(self, "_cached_middle_term"): | ||
| middle_term = self._cached_middle_term | ||
| else: | ||
| middle_term = prior_covar.mul(-1).to_dense() | ||
| if variational_inducing_covar is not None: | ||
| middle_term = variational_inducing_covar.to_dense() + middle_term | ||
| middle_term = middle_term.type(torch.float64) | ||
|
|
||
| middle_term = torch.linalg.solve_triangular(chol, middle_term, upper=False, left=False) | ||
| middle_term = torch.linalg.solve_triangular(chol.mT, middle_term, upper=True, left=True) | ||
| if not self.training: | ||
| self._cached_middle_term = middle_term | ||
|
|
||
| # The covariance update `K_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} K_ZX` | ||
| if diag and self.training: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -111,3 +111,207 @@ def test_against_variational_strategy(self, train: bool = True): | |
|
|
||
| def test_against_variational_strategy_eval(self): | ||
| self.test_against_variational_strategy(train=False) | ||
|
|
||
| def test_inference_caching(self): | ||
| """Test that inference caching works correctly.""" | ||
| torch.manual_seed(42) | ||
|
|
||
| model, likelihood = self._make_model_and_likelihood( | ||
| batch_shape=self.batch_shape, | ||
| strategy_cls=self.strategy_cls, | ||
| distribution_cls=self.distribution_cls, | ||
| ) | ||
|
|
||
| # Training - caches should not exist | ||
| model.train() | ||
| likelihood.train() | ||
|
|
||
| train_x = torch.rand(32, 2) | ||
| _ = model(train_x) | ||
|
|
||
| # Check that caches don't exist in training mode | ||
| self.assertFalse(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) | ||
| self.assertFalse(hasattr(model.variational_strategy, "_cached_middle_term")) | ||
|
|
||
| # Switch to eval mode | ||
| model.eval() | ||
| likelihood.eval() | ||
|
|
||
| # First inference call should create caches | ||
| test_x = torch.rand(5, 2) | ||
| with torch.no_grad(): | ||
| _ = model(test_x) | ||
|
|
||
| # Check caches exist | ||
| self.assertTrue(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) | ||
| self.assertTrue(hasattr(model.variational_strategy, "_cached_middle_term")) | ||
|
|
||
| # Second inference call should use cached values | ||
| cached_inv_chol = model.variational_strategy._cached_inv_chol_t_inducing_values | ||
| cached_middle = model.variational_strategy._cached_middle_term | ||
|
|
||
| with torch.no_grad(): | ||
| _ = model(test_x) | ||
|
|
||
| # Verify caches are the same objects (not recomputed) | ||
| self.assertTrue(model.variational_strategy._cached_inv_chol_t_inducing_values is cached_inv_chol) | ||
| self.assertTrue(model.variational_strategy._cached_middle_term is cached_middle) | ||
jacobrgardner marked this conversation as resolved.
Show resolved
Hide resolved
jacobrgardner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Switching back to train mode should clear caches | ||
| model.train() | ||
| _ = model(test_x) | ||
| self.assertFalse(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) | ||
| self.assertFalse(hasattr(model.variational_strategy, "_cached_middle_term")) | ||
|
|
||
| def test_onnx_export(self): | ||
| """Test that a trained SVGP with LargeBatchVariationalStrategy can be exported to ONNX in fp64.""" | ||
| import os | ||
| import tempfile | ||
|
|
||
| try: | ||
| import onnx | ||
| import onnxruntime as ort | ||
| except ImportError: | ||
| self.skipTest("onnx and onnxruntime required for this test") | ||
|
|
||
| from torch.onnx import register_custom_op_symbolic | ||
|
|
||
| # Create and train model in fp64 | ||
| torch.manual_seed(42) | ||
| model, likelihood = self._make_model_and_likelihood( | ||
| batch_shape=self.batch_shape, | ||
| strategy_cls=self.strategy_cls, | ||
| distribution_cls=self.distribution_cls, | ||
| ) | ||
| model = model.double() | ||
| likelihood = likelihood.double() | ||
|
|
||
| # Quick training on toy data (fp64) | ||
| train_x = torch.rand(32, 2, dtype=torch.float64) | ||
| train_y = torch.sin(train_x[:, 0]) + 0.1 * torch.randn(32, dtype=torch.float64) | ||
|
|
||
| model.train() | ||
| likelihood.train() | ||
| optimizer = torch.optim.Adam(model.parameters(), lr=0.1) | ||
| from gpytorch.mlls import VariationalELBO | ||
jacobrgardner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| mll = VariationalELBO(likelihood, model, num_data=train_x.size(0)) | ||
|
|
||
| for _ in range(5): | ||
| optimizer.zero_grad() | ||
| output = model(train_x) | ||
| loss = -mll(output, train_y) | ||
| loss.backward() | ||
| optimizer.step() | ||
|
|
||
| # Switch to eval mode | ||
| model.eval() | ||
| likelihood.eval() | ||
|
|
||
| # Create a wrapper that returns the mean as a tensor (required for ONNX) | ||
| class MeanPredictionWrapper(torch.nn.Module): | ||
| def __init__(self, gp_model): | ||
| super().__init__() | ||
| self.gp_model = gp_model | ||
|
|
||
| def forward(self, x): | ||
| output = self.gp_model(x) | ||
| return output.mean | ||
|
|
||
| wrapper = MeanPredictionWrapper(model) | ||
| wrapper.eval() | ||
|
|
||
| # Do a dummy forward to populate the inference caches | ||
| test_x = torch.rand(5, 2, dtype=torch.float64) | ||
| with torch.no_grad(): | ||
| _ = wrapper(test_x) | ||
|
|
||
| # Verify caches are populated and are in fp64 | ||
| self.assertTrue(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) | ||
| self.assertTrue(hasattr(model.variational_strategy, "_cached_middle_term")) | ||
| self.assertEqual(model.variational_strategy._cached_inv_chol_t_inducing_values.dtype, torch.float64) | ||
| self.assertEqual(model.variational_strategy._cached_middle_term.dtype, torch.float64) | ||
|
|
||
| # Get reference output for comparison | ||
| with torch.no_grad(): | ||
| mean_pred = wrapper(test_x) | ||
| self.assertEqual(mean_pred.shape, torch.Size([5])) | ||
| self.assertEqual(mean_pred.dtype, torch.float64) | ||
|
|
||
| # Register custom ONNX symbolics for ops not in default registry | ||
| def mT_symbolic(g, self): | ||
| # mT swaps the last two dimensions | ||
| tensor_type = self.type() | ||
| if tensor_type is not None and tensor_type.dim() is not None: | ||
| rank = tensor_type.dim() | ||
| perm = list(range(rank - 2)) + [rank - 1, rank - 2] | ||
| return g.op("Transpose", self, perm_i=perm) | ||
| return g.op("Transpose", self, perm_i=[1, 0]) | ||
kayween marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def softplus_symbolic(g, self, beta, threshold): | ||
| # Numerically stable Softplus using Where: | ||
| # softplus(x) = x + log(1 + exp(-x)) for x > 0 | ||
| # = log(1 + exp(x)) for x <= 0 | ||
| scaled = g.op("Mul", self, beta) | ||
| zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64)) | ||
| one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64)) | ||
|
Comment on lines
+256
to
+257
Collaborator
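There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: maybe use [inline code suggestion lost in extraction; it refers to the two `g.op("Constant", ...)` lines that build `zero` and `one`]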
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, these are just more ONNX idiosyncrasies, unfortunately. |
||
| condition = g.op("Greater", scaled, zero) | ||
|
|
||
| neg_scaled = g.op("Neg", scaled) | ||
| exp_neg = g.op("Exp", neg_scaled) | ||
| one_plus_exp_neg = g.op("Add", one, exp_neg) | ||
| log_pos = g.op("Log", one_plus_exp_neg) | ||
| result_pos = g.op("Add", scaled, log_pos) | ||
|
|
||
| exp_scaled = g.op("Exp", scaled) | ||
| one_plus_exp = g.op("Add", one, exp_scaled) | ||
| result_neg = g.op("Log", one_plus_exp) | ||
|
|
||
| stable_result = g.op("Where", condition, result_pos, result_neg) | ||
| return g.op("Div", stable_result, beta) | ||
|
Comment on lines
+251
to
+271
Collaborator
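There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am not super familiar with ONNX. Does ONNX not have [inline code lost in extraction; judging from the reply, the question is about built-in softplus and transpose support]?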
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It has softplus, but only an fp32 version (for whatever reason). It supports transpose, but not [inline code lost in extraction; presumably `mT`, for which the test registers a custom `aten::mT` symbolic]. |
||
|
|
||
| try: | ||
| register_custom_op_symbolic("aten::mT", mT_symbolic, 17) | ||
| except Exception: | ||
| pass | ||
| try: | ||
| register_custom_op_symbolic("aten::softplus", softplus_symbolic, 17) | ||
| except Exception: | ||
jacobrgardner marked this conversation as resolved.
Show resolved
Hide resolved
jacobrgardner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| pass | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| onnx_path = os.path.join(tmpdir, "model.onnx") | ||
|
|
||
| # Build export kwargs - use legacy exporter if available (PyTorch 2.9+) | ||
jacobrgardner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| export_kwargs = dict( | ||
| input_names=["input"], | ||
| output_names=["mean"], | ||
| opset_version=17, | ||
| ) | ||
| # dynamo=False forces the legacy TorchScript-based exporter | ||
| if hasattr(torch.onnx.export, "__wrapped__") or torch.__version__ >= "2.9": | ||
| export_kwargs["dynamo"] = False | ||
|
|
||
| torch.onnx.export( | ||
| wrapper, | ||
| test_x, | ||
| onnx_path, | ||
| **export_kwargs, | ||
| ) | ||
|
|
||
| self.assertTrue(os.path.exists(onnx_path)) | ||
| self.assertGreater(os.path.getsize(onnx_path), 0) | ||
|
|
||
| # Verify the ONNX model input is fp64 | ||
| model_onnx = onnx.load(onnx_path) | ||
| input_type = model_onnx.graph.input[0].type.tensor_type.elem_type | ||
| self.assertEqual(input_type, 11) # ONNX TensorProto.DOUBLE = 11 | ||
|
|
||
| # Verify with onnxruntime | ||
| sess_options = ort.SessionOptions() | ||
| sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL | ||
| session = ort.InferenceSession(onnx_path, sess_options) | ||
| onnx_output = session.run(None, {"input": test_x.numpy()})[0] | ||
|
|
||
| # Compare ONNX output with PyTorch output | ||
| self.assertAllClose(torch.from_numpy(onnx_output), mean_pred, rtol=1e-5, atol=1e-5) | ||
Uh oh!
There was an error while loading. Please reload this page.