Skip to content

Commit 55b8911

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Implement ContextualDataset.__eq__ (#3005)
Summary: Pull Request resolved: #3005 -- Reviewed By: esantorella Differential Revision: D81800547 fbshipit-source-id: cb2a50cc13e52cd4d0c3dad262bf6f171f1f29f7
1 parent 3135a7b commit 55b8911

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

botorch/utils/datasets.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,6 @@ def __init__(
601601
self.datasets: dict[str, SupervisedDataset] = {
602602
ds.outcome_names[0]: ds for ds in datasets
603603
}
604-
self.feature_names = datasets[0].feature_names
605-
self.outcome_names = list(self.datasets.keys())
606604
self.parameter_decomposition = parameter_decomposition
607605
self.metric_decomposition = metric_decomposition
608606
self._validate_datasets()
@@ -614,6 +612,14 @@ def __init__(
614612
}
615613
self.group_indices = None
616614

615+
@property
616+
def feature_names(self) -> list[str]:
617+
return self.datasets[self.outcome_names[0]].feature_names
618+
619+
@property
620+
def outcome_names(self) -> list[str]:
621+
return list(self.datasets.keys())
622+
617623
@property
618624
def X(self) -> Tensor:
619625
return self.datasets[self.outcome_names[0]].X
@@ -737,6 +743,14 @@ def _validate_decompositions(self) -> None:
737743
f"{outcome} is missing in metric_decomposition."
738744
)
739745

746+
def __eq__(self, other: Any) -> bool:
747+
return (
748+
type(other) is type(self)
749+
and self.datasets == other.datasets
750+
and self.parameter_decomposition == other.parameter_decomposition
751+
and self.metric_decomposition == other.metric_decomposition
752+
)
753+
740754
def clone(
741755
self, deepcopy: bool = False, mask: Tensor | None = None
742756
) -> ContextualDataset:

test/utils/test_datasets.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,17 @@ def test_clone_contextual_dataset(self):
729729
)
730730
else:
731731
self.assertIsNone(context_dt2.metric_decomposition)
732+
733+
def test_contextual_dataset_equality(self) -> None:
734+
context_dt, _ = make_contextual_dataset(has_yvar=True, contextual_outcome=True)
735+
clone = context_dt.clone()
736+
self.assertEqual(context_dt, clone)
737+
for yvar, outcome in (
738+
(True, False),
739+
(False, True),
740+
(False, False),
741+
):
742+
new_dt, _ = make_contextual_dataset(
743+
has_yvar=yvar, contextual_outcome=outcome
744+
)
745+
self.assertNotEqual(context_dt, new_dt)

0 commit comments

Comments
 (0)