|
7 | 7 | import math |
8 | 8 |
|
9 | 9 | import torch |
| 10 | +from botorch.acquisition import qAnalyticProbabilityOfImprovement |
10 | 11 | from botorch.acquisition.analytic import ( |
11 | 12 | _compute_log_prob_feas, |
12 | 13 | _ei_helper, |
@@ -362,6 +363,161 @@ def test_probability_of_improvement_batch(self): |
362 | 363 | LogProbabilityOfImprovement(model=mm2, best_f=0.0) |
363 | 364 |
|
364 | 365 |
|
class TestqAnalyticProbabilityOfImprovement(BotorchTestCase):
    """Tests for qAnalyticProbabilityOfImprovement.

    Each case builds a MockModel around a fixed Gaussian posterior so the
    expected probability of improvement can be computed in closed form from
    the standard normal CDF (e.g. P(Z > 1.96) ~= 0.0250, Phi(1) ~= 0.8413).
    """

    def test_q_analytic_probability_of_improvement(self):
        for dtype in (torch.float, torch.double):
            mean = torch.zeros(1, device=self.device, dtype=dtype)
            cov = torch.eye(n=1, device=self.device, dtype=dtype)
            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)

            # basic test: N(0, 1) posterior, best_f = 1.96 -> P(Z > 1.96)
            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96)
            X = torch.rand(1, 2, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.0250, device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # basic test, minimize (maximize=False) -> P(Z < 1.96)
            module = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=1.96, maximize=False
            )
            X = torch.rand(1, 2, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.9750, device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # basic test, posterior transform (single-output):
            # scalarized mean 0.5, std 0.5 -> z = 1 -> Phi(1)
            mean = torch.ones(1, device=self.device, dtype=dtype)
            cov = torch.eye(n=1, device=self.device, dtype=dtype)
            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)
            weights = torch.tensor([0.5], device=self.device, dtype=dtype)
            transform = ScalarizedPosteriorTransform(weights)
            module = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            X = torch.rand(1, 2, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.8413, device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # basic test, posterior transform (multi-output):
            # scalarized mean 2, variance 2 -> z = sqrt(2) -> Phi(sqrt(2))
            mean = torch.ones(1, 2, device=self.device, dtype=dtype)
            cov = torch.eye(n=2, device=self.device, dtype=dtype).unsqueeze(0)
            mvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)
            weights = torch.ones(2, device=self.device, dtype=dtype)
            transform = ScalarizedPosteriorTransform(weights)
            module = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            X = torch.rand(1, 1, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.9214, device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # basic test, q = 2: two i.i.d. N(0, 1) points,
            # P(max > 1.96) = 1 - 0.975^2 = 0.049375
            mean = torch.zeros(2, device=self.device, dtype=dtype)
            cov = torch.eye(n=2, device=self.device, dtype=dtype)
            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)
            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96)
            X = torch.zeros(2, 2, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.049375, device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

    def test_batch_q_analytic_probability_of_improvement(self):
        for dtype in (torch.float, torch.double):
            # test batch mode: means 0 and 1 with best_f = 0
            # -> [P(Z > 0), P(Z > -1)] = [0.5, Phi(1)]
            mean = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype)
            cov = (
                torch.eye(n=1, device=self.device, dtype=dtype)
                .unsqueeze(0)
                .repeat(2, 1, 1)
            )
            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)
            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=0)
            X = torch.rand(2, 1, 1, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor([0.5, 0.8413], device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # test batched model and batched best_f values
            mean = torch.zeros(2, 1, device=self.device, dtype=dtype)
            cov = (
                torch.eye(n=1, device=self.device, dtype=dtype)
                .unsqueeze(0)
                .repeat(2, 1, 1)
            )
            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)
            best_f = torch.tensor([0.0, -1.0], device=self.device, dtype=dtype)
            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=best_f)
            X = torch.rand(2, 1, 1, device=self.device, dtype=dtype)
            pi = module(X)
            # NOTE: expected is 1-D to match pi's batch shape; the original
            # [[0.5, 0.8413]] only passed via broadcasting in allclose.
            pi_expected = torch.tensor([0.5, 0.8413], device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # test batched model, posterior transform (single output)
            mean = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype)
            cov = (
                torch.eye(n=1, device=self.device, dtype=dtype)
                .unsqueeze(0)
                .repeat(2, 1, 1)
            )
            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)
            weights = torch.tensor([0.5], device=self.device, dtype=dtype)
            transform = ScalarizedPosteriorTransform(weights)
            module = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            X = torch.rand(2, 1, 2, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor([0.5, 0.8413], device=self.device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # test batched model, posterior transform (multiple outputs)
            mean = torch.tensor(
                [[[1.0, 1.0]], [[0.0, 1.0]]], device=self.device, dtype=dtype
            )
            cov = (
                torch.eye(n=2, device=self.device, dtype=dtype)
                .unsqueeze(0)
                .repeat(2, 1, 1)
            )
            mvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
            posterior = GPyTorchPosterior(mvn)
            mm = MockModel(posterior)
            weights = torch.ones(2, device=self.device, dtype=dtype)
            transform = ScalarizedPosteriorTransform(weights)
            module = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            X = torch.rand(2, 1, 2, device=self.device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(
                [0.9214, 0.7602], device=self.device, dtype=dtype
            )
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # test bad posterior transform class: an MC objective is not a
            # PosteriorTransform and must be rejected at construction time
            with self.assertRaises(UnsupportedError):
                qAnalyticProbabilityOfImprovement(
                    model=mm, best_f=0.0, posterior_transform=IdentityMCObjective()
                )
| 519 | + |
| 520 | + |
365 | 521 | class TestUpperConfidenceBound(BotorchTestCase): |
366 | 522 | def test_upper_confidence_bound(self): |
367 | 523 | for dtype in (torch.float, torch.double): |
|
0 commit comments