|
3 | 3 |
|
4 | 4 | import torch |
5 | 5 |
|
| 6 | +from gpytorch.mlls import VariationalELBO |
6 | 7 | from gpytorch.test.base_test_case import BaseTestCase |
7 | 8 | from gpytorch.variational.large_batch_variational_strategy import LargeBatchVariationalStrategy, QuadFormDiagonal |
8 | 9 | from gpytorch.variational.variational_strategy import VariationalStrategy |
@@ -111,3 +112,206 @@ def test_against_variational_strategy(self, train: bool = True): |
111 | 112 |
|
    def test_against_variational_strategy_eval(self):
        """Re-run ``test_against_variational_strategy`` with ``train=False``.

        Exercises the eval-mode code path of the comparison test defined above
        (the train-mode variant runs with the default ``train=True``).
        """
        self.test_against_variational_strategy(train=False)
| 115 | + |
| 116 | + def test_inference_caching(self): |
| 117 | + """Test that inference caching works correctly.""" |
| 118 | + torch.manual_seed(42) |
| 119 | + |
| 120 | + model, likelihood = self._make_model_and_likelihood( |
| 121 | + batch_shape=self.batch_shape, |
| 122 | + strategy_cls=self.strategy_cls, |
| 123 | + distribution_cls=self.distribution_cls, |
| 124 | + ) |
| 125 | + |
| 126 | + # Training - caches should not exist |
| 127 | + model.train() |
| 128 | + likelihood.train() |
| 129 | + |
| 130 | + train_x = torch.rand(32, 2) |
| 131 | + _ = model(train_x) |
| 132 | + |
| 133 | + # Check that caches don't exist in training mode |
| 134 | + self.assertFalse(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) |
| 135 | + self.assertFalse(hasattr(model.variational_strategy, "_cached_middle_term")) |
| 136 | + |
| 137 | + # Switch to eval mode |
| 138 | + model.eval() |
| 139 | + likelihood.eval() |
| 140 | + |
| 141 | + # First inference call should create caches |
| 142 | + test_x = torch.rand(5, 2) |
| 143 | + with torch.no_grad(): |
| 144 | + _ = model(test_x) |
| 145 | + |
| 146 | + # Check caches exist |
| 147 | + self.assertTrue(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) |
| 148 | + self.assertTrue(hasattr(model.variational_strategy, "_cached_middle_term")) |
| 149 | + |
| 150 | + # Second inference call should use cached values |
| 151 | + cached_inv_chol = model.variational_strategy._cached_inv_chol_t_inducing_values |
| 152 | + cached_middle = model.variational_strategy._cached_middle_term |
| 153 | + |
| 154 | + with torch.no_grad(): |
| 155 | + _ = model(test_x) |
| 156 | + |
| 157 | + # Verify caches are the same objects (not recomputed) |
| 158 | + self.assertTrue(model.variational_strategy._cached_inv_chol_t_inducing_values is cached_inv_chol) |
| 159 | + self.assertTrue(model.variational_strategy._cached_middle_term is cached_middle) |
| 160 | + |
| 161 | + # Switching back to train mode should clear caches |
| 162 | + model.train() |
| 163 | + _ = model(test_x) |
| 164 | + self.assertFalse(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) |
| 165 | + self.assertFalse(hasattr(model.variational_strategy, "_cached_middle_term")) |
| 166 | + |
| 167 | + def test_onnx_export(self): |
| 168 | + """Test that a trained SVGP with LargeBatchVariationalStrategy can be exported to ONNX in fp64.""" |
| 169 | + import os |
| 170 | + import tempfile |
| 171 | + |
| 172 | + try: |
| 173 | + import onnx |
| 174 | + import onnxruntime as ort |
| 175 | + except ImportError: |
| 176 | + self.skipTest("onnx and onnxruntime required for this test") |
| 177 | + |
| 178 | + from torch.onnx import register_custom_op_symbolic |
| 179 | + |
| 180 | + # Create and train model in fp64 |
| 181 | + torch.manual_seed(42) |
| 182 | + model, likelihood = self._make_model_and_likelihood( |
| 183 | + batch_shape=self.batch_shape, |
| 184 | + strategy_cls=self.strategy_cls, |
| 185 | + distribution_cls=self.distribution_cls, |
| 186 | + ) |
| 187 | + model = model.double() |
| 188 | + likelihood = likelihood.double() |
| 189 | + |
| 190 | + # Quick training on toy data (fp64) |
| 191 | + train_x = torch.rand(32, 2, dtype=torch.float64) |
| 192 | + train_y = torch.sin(train_x[:, 0]) + 0.1 * torch.randn(32, dtype=torch.float64) |
| 193 | + |
| 194 | + model.train() |
| 195 | + likelihood.train() |
| 196 | + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) |
| 197 | + |
| 198 | + mll = VariationalELBO(likelihood, model, num_data=train_x.size(0)) |
| 199 | + |
| 200 | + for _ in range(5): |
| 201 | + optimizer.zero_grad() |
| 202 | + output = model(train_x) |
| 203 | + loss = -mll(output, train_y) |
| 204 | + loss.backward() |
| 205 | + optimizer.step() |
| 206 | + |
| 207 | + # Switch to eval mode |
| 208 | + model.eval() |
| 209 | + likelihood.eval() |
| 210 | + |
| 211 | + # Create a wrapper that returns the mean as a tensor (required for ONNX) |
| 212 | + class MeanPredictionWrapper(torch.nn.Module): |
| 213 | + def __init__(self, gp_model): |
| 214 | + super().__init__() |
| 215 | + self.gp_model = gp_model |
| 216 | + |
| 217 | + def forward(self, x): |
| 218 | + output = self.gp_model(x) |
| 219 | + return output.mean |
| 220 | + |
| 221 | + wrapper = MeanPredictionWrapper(model) |
| 222 | + wrapper.eval() |
| 223 | + |
| 224 | + # Do a dummy forward to populate the inference caches |
| 225 | + test_x = torch.rand(5, 2, dtype=torch.float64) |
| 226 | + with torch.no_grad(): |
| 227 | + _ = wrapper(test_x) |
| 228 | + |
| 229 | + # Verify caches are populated and are in fp64 |
| 230 | + self.assertTrue(hasattr(model.variational_strategy, "_cached_inv_chol_t_inducing_values")) |
| 231 | + self.assertTrue(hasattr(model.variational_strategy, "_cached_middle_term")) |
| 232 | + self.assertEqual(model.variational_strategy._cached_inv_chol_t_inducing_values.dtype, torch.float64) |
| 233 | + self.assertEqual(model.variational_strategy._cached_middle_term.dtype, torch.float64) |
| 234 | + |
| 235 | + # Get reference output for comparison |
| 236 | + with torch.no_grad(): |
| 237 | + mean_pred = wrapper(test_x) |
| 238 | + self.assertEqual(mean_pred.shape, torch.Size([5])) |
| 239 | + self.assertEqual(mean_pred.dtype, torch.float64) |
| 240 | + |
| 241 | + # Register custom ONNX symbolics for ops not in default registry |
| 242 | + def mT_symbolic(g, self): |
| 243 | + # mT swaps the last two dimensions |
| 244 | + tensor_type = self.type() |
| 245 | + if tensor_type is not None and tensor_type.dim() is not None: |
| 246 | + rank = tensor_type.dim() |
| 247 | + perm = list(range(rank - 2)) + [rank - 1, rank - 2] |
| 248 | + return g.op("Transpose", self, perm_i=perm) |
| 249 | + return g.op("Transpose", self, perm_i=[1, 0]) |
| 250 | + |
| 251 | + def softplus_symbolic(g, self, beta, threshold): |
| 252 | + # Numerically stable Softplus using Where: |
| 253 | + # softplus(x) = x + log(1 + exp(-x)) for x > 0 |
| 254 | + # = log(1 + exp(x)) for x <= 0 |
| 255 | + scaled = g.op("Mul", self, beta) |
| 256 | + zero = g.op("Constant", value_t=torch.tensor([0.0], dtype=torch.float64)) |
| 257 | + one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float64)) |
| 258 | + condition = g.op("Greater", scaled, zero) |
| 259 | + |
| 260 | + neg_scaled = g.op("Neg", scaled) |
| 261 | + exp_neg = g.op("Exp", neg_scaled) |
| 262 | + one_plus_exp_neg = g.op("Add", one, exp_neg) |
| 263 | + log_pos = g.op("Log", one_plus_exp_neg) |
| 264 | + result_pos = g.op("Add", scaled, log_pos) |
| 265 | + |
| 266 | + exp_scaled = g.op("Exp", scaled) |
| 267 | + one_plus_exp = g.op("Add", one, exp_scaled) |
| 268 | + result_neg = g.op("Log", one_plus_exp) |
| 269 | + |
| 270 | + stable_result = g.op("Where", condition, result_pos, result_neg) |
| 271 | + return g.op("Div", stable_result, beta) |
| 272 | + |
| 273 | + try: |
| 274 | + register_custom_op_symbolic("aten::mT", mT_symbolic, 17) |
| 275 | + except Exception: |
| 276 | + pass |
| 277 | + try: |
| 278 | + register_custom_op_symbolic("aten::softplus", softplus_symbolic, 17) |
| 279 | + except Exception: |
| 280 | + pass |
| 281 | + |
| 282 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 283 | + onnx_path = os.path.join(tmpdir, "model.onnx") |
| 284 | + |
| 285 | + # Build export kwargs - use legacy exporter if available (PyTorch 2.9+) |
| 286 | + export_kwargs = dict( |
| 287 | + input_names=["input"], |
| 288 | + output_names=["mean"], |
| 289 | + opset_version=17, |
| 290 | + ) |
| 291 | + # dynamo=False forces the legacy TorchScript-based exporter |
| 292 | + if hasattr(torch.onnx.export, "__wrapped__") or torch.__version__ >= "2.9": |
| 293 | + export_kwargs["dynamo"] = False |
| 294 | + |
| 295 | + torch.onnx.export( |
| 296 | + wrapper, |
| 297 | + test_x, |
| 298 | + onnx_path, |
| 299 | + **export_kwargs, |
| 300 | + ) |
| 301 | + |
| 302 | + self.assertTrue(os.path.exists(onnx_path)) |
| 303 | + self.assertGreater(os.path.getsize(onnx_path), 0) |
| 304 | + |
| 305 | + # Verify the ONNX model input is fp64 |
| 306 | + model_onnx = onnx.load(onnx_path) |
| 307 | + input_type = model_onnx.graph.input[0].type.tensor_type.elem_type |
| 308 | + self.assertEqual(input_type, 11) # ONNX TensorProto.DOUBLE = 11 |
| 309 | + |
| 310 | + # Verify with onnxruntime |
| 311 | + sess_options = ort.SessionOptions() |
| 312 | + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL |
| 313 | + session = ort.InferenceSession(onnx_path, sess_options) |
| 314 | + onnx_output = session.run(None, {"input": test_x.numpy()})[0] |
| 315 | + |
| 316 | + # Compare ONNX output with PyTorch output |
| 317 | + self.assertAllClose(torch.from_numpy(onnx_output), mean_pred, rtol=1e-5, atol=1e-5) |
0 commit comments