 from botorch.exceptions.errors import InputDataError, UnsupportedError
 from botorch.utils.containers import DenseContainer, SliceContainer
 from botorch.utils.datasets import (
+    ContextualDataset,
     FixedNoiseDataset,
     MultiTaskDataset,
     RankingDataset,
@@ -335,3 +336,175 @@ def test_multi_task(self): |
             task_feature_index=-1,
             target_task_value=0,
         )
+
+    def test_contextual_datasets(self):
+        num_contexts = 3
+        feature_names = [f"x_c{i}" for i in range(num_contexts)]
+        parameter_decomposition = {
+            f"context_{i}": [f"x_c{i}"] for i in range(num_contexts)
+        }
+        context_buckets = list(parameter_decomposition.keys())
+        context_outcome_list = [f"y:context_{i}" for i in range(num_contexts)]
+        metric_decomposition = {f"{c}": [f"y:{c}"] for c in context_buckets}
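+        # Shared fixture: each context bucket owns exactly one parameter
+        # ("x_c{i}") and one context-level outcome ("y:context_{i}").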
+
+        # Test construction from a single dataset with one aggregated outcome.
+        dataset_list1 = [
+            make_dataset(
+                d=1 * num_contexts,
+                has_yvar=True,
+                feature_names=feature_names,
+                outcome_names=["y"],
+            )
+        ]
+        context_dt = ContextualDataset(
+            datasets=dataset_list1,
+            parameter_decomposition=parameter_decomposition,
+            context_buckets=context_buckets,
+        )
+        self.assertEqual(len(context_dt.datasets), len(dataset_list1))
+        self.assertListEqual(context_dt.context_buckets, context_buckets)
+        self.assertListEqual(context_dt.outcome_names, ["y"])
+        self.assertListEqual(context_dt.feature_names, feature_names)
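+        # With a single dataset, X, Y, and Yvar should be the underlying
+        # tensors themselves (identity, not copies).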
+        self.assertIs(context_dt.datasets["y"], dataset_list1[0])
+        self.assertIs(context_dt.X, dataset_list1[0].X)
+        self.assertIs(context_dt.Y, dataset_list1[0].Y)
+        self.assertIs(context_dt.Yvar, dataset_list1[0].Yvar)
+
+        # Test construction with one dataset per context-level outcome.
+        dataset_list2 = [
+            make_dataset(
+                d=1 * num_contexts,
+                has_yvar=True,
+                feature_names=feature_names,
+                outcome_names=[context_outcome_list[0]],
+            )
+        ]
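+        # The remaining context outcomes reuse the first dataset's X, since
+        # all context buckets are expected to share identical inputs.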
+        for m in context_outcome_list[1:]:
+            dataset_list2.append(
+                SupervisedDataset(
+                    X=dataset_list2[0].X,
+                    Y=rand(dataset_list2[0].Y.size()),
+                    Yvar=rand(dataset_list2[0].Yvar.size()),
+                    feature_names=feature_names,
+                    outcome_names=[m],
+                )
+            )
+        context_dt = ContextualDataset(
+            datasets=dataset_list2,
+            parameter_decomposition=parameter_decomposition,
+            context_buckets=context_buckets,
+            metric_decomposition=metric_decomposition,
+        )
+        self.assertEqual(len(context_dt.datasets), len(dataset_list2))
+        self.assertListEqual(context_dt.context_buckets, context_buckets)
+        self.assertListEqual(context_dt.outcome_names, context_outcome_list)
+        self.assertListEqual(context_dt.feature_names, feature_names)
+        self.assertTrue(torch.equal(context_dt.X, dataset_list2[-1].X))
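+        # Y and Yvar should be stacked with one column per context outcome.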
+        self.assertEqual(context_dt.Y.shape[-1], len(context_outcome_list))
+        self.assertEqual(context_dt.Yvar.shape[-1], len(context_outcome_list))
+        for dt in dataset_list2:
+            self.assertIs(context_dt.datasets[dt.outcome_names[0]], dt)
+
+        # Test that the order of context_buckets determines outcome ordering.
+        context_dt_reverse = ContextualDataset(
+            datasets=dataset_list2,
+            parameter_decomposition=parameter_decomposition,
+            context_buckets=context_buckets[::-1],  # reverse order
+            metric_decomposition=metric_decomposition,
+        )
+        self.assertListEqual(
+            context_dt_reverse.outcome_names, context_outcome_list[::-1]
+        )
+        self.assertTrue(
+            torch.equal(context_dt.Y, torch.flip(context_dt_reverse.Y, (1,)))
+        )
+        self.assertTrue(
+            torch.equal(context_dt.Yvar, torch.flip(context_dt_reverse.Yvar, (1,)))
+        )
+
+        # Test dataset validation. A metric bucket must contain exactly one
+        # outcome.
+        wrong_metric_decomposition = {
+            f"{c}": [f"y:{c}"] for c in context_buckets if c != "context_0"
+        }
+        wrong_metric_decomposition["context_0"] = ["y:context_0", "y:context_1"]
+        with self.assertRaisesRegex(
+            ValueError, "context_0 bucket contains multiple outcomes"
+        ):
+            ContextualDataset(
+                datasets=dataset_list2,
+                parameter_decomposition=parameter_decomposition,
+                context_buckets=context_buckets,
+                metric_decomposition=wrong_metric_decomposition,
+            )
+
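+        # All context buckets must share an identical X tensor.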
+        with self.assertRaisesRegex(
+            InputDataError, "Require same X for context buckets"
+        ):
+            ContextualDataset(
+                datasets=[
+                    make_dataset(d=num_contexts, outcome_names=[m])
+                    for m in context_outcome_list
+                ],
+                parameter_decomposition=parameter_decomposition,
+                context_buckets=context_buckets,
+            )
+
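+        # metric_decomposition is required whenever multiple datasets are
+        # passed in.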
+        with self.assertRaisesRegex(
+            InputDataError,
+            "metric_decomposition must be provided when there are multiple datasets.",
+        ):
+            ContextualDataset(
+                datasets=dataset_list2,
+                parameter_decomposition=parameter_decomposition,
+                context_buckets=context_buckets,
+            )
+
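+        # Conversely, metric_decomposition is redundant with a single
+        # aggregated dataset.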
+        with self.assertRaisesRegex(
+            InputDataError,
+            "metric_decomposition is redundant when there is "
+            "one dataset for overall outcome.",
+        ):
+            ContextualDataset(
+                datasets=dataset_list1,
+                parameter_decomposition=parameter_decomposition,
+                context_buckets=context_buckets,
+                metric_decomposition=metric_decomposition,
+            )
+
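+        # Keys of parameter_decomposition must match context_buckets.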
+        with self.assertRaisesRegex(
+            InputDataError,
+            "Keys of parameter decomposition and context buckets do not match.",
+        ):
+            ContextualDataset(
+                datasets=dataset_list1,
+                parameter_decomposition=parameter_decomposition,
+                context_buckets=["context_0", "context_1"],
+            )
+
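+        # Keys of metric_decomposition must match context_buckets as well.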
+        with self.assertRaisesRegex(
+            InputDataError,
+            "Keys of metric decomposition and context buckets do not match.",
+        ):
+            ContextualDataset(
+                datasets=dataset_list2,
+                parameter_decomposition=parameter_decomposition,
+                context_buckets=context_buckets,
+                metric_decomposition={
+                    f"{c}": [f"y:{c}"] for c in context_buckets if c != "context_0"
+                },
+            )
+
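+        # Every dataset outcome must appear somewhere in metric_decomposition.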
+        wrong_metric_decomposition = {
+            f"{c}": [f"y:{c}"] for c in context_buckets if c != "context_0"
+        }
+        wrong_metric_decomposition["context_0"] = ["wrong_metric"]
+        missing_outcome = "y:context_0"
+        with self.assertRaisesRegex(
+            InputDataError, f"{missing_outcome} is missing in metric_decomposition."
+        ):
+            ContextualDataset(
+                datasets=dataset_list2,
+                parameter_decomposition=parameter_decomposition,
+                context_buckets=context_buckets,
+                metric_decomposition=wrong_metric_decomposition,
+            )