Skip to content

Commit 01b2503

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix dataset equality checks (#2077)
Summary: Pull Request resolved: #2077 This was erroring out when self.Yvar was a tensor but other.Yvar was None. Updated to equality checks to handle this case. Reviewed By: Balandat, esantorella Differential Revision: D50818635 fbshipit-source-id: 024e2c12160137202ea1deee2d96e165ef816afc
1 parent 8cdc595 commit 01b2503

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

botorch/utils/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __eq__(self, other: Any) -> bool:
144144
and (
145145
other.Yvar is None
146146
if self.Yvar is None
147-
else torch.equal(self.Yvar, other.Yvar)
147+
else other.Yvar is not None and torch.equal(self.Yvar, other.Yvar)
148148
)
149149
and self.feature_names == other.feature_names
150150
and self.outcome_names == other.outcome_names

test/utils/test_datasets.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ def test_supervised(self):
118118
self.assertIsInstance(dataset.Yvar, Tensor)
119119
self.assertIsInstance(dataset._Yvar, DenseContainer)
120120

121+
# More equality checks with & without Yvar.
122+
self.assertEqual(dataset, dataset)
123+
self.assertNotEqual(dataset, dataset2)
124+
self.assertNotEqual(dataset2, dataset)
125+
121126
def test_fixedNoise(self):
122127
# Generate some data
123128
X = rand(3, 2)

0 commit comments

Comments
 (0)