Skip to content

Commit ffb7bd9

Browse files
Balandatfacebook-github-bot
authored andcommitted
Fix flaky test_get_best_f_mc test (#1969)
Summary: Pull Request resolved: #1969 This test started to become flaky recently, presumably b/c to some minor changes in the numerics on GPU computations. Replacing equality with closeness check fixes this. Reviewed By: saitcakmak Differential Revision: D48033053 fbshipit-source-id: 5261ec92ca67aef82567c994aeb6ca802a1e07a4
1 parent f20697a commit ffb7bd9

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

test/acquisition/test_input_constructors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,22 +140,22 @@ def test_get_best_f_mc(self):
140140
best_f = get_best_f_mc(training_data=self.blockX_blockY)
141141
self.assertEqual(best_f, get_best_f_mc(self.blockX_blockY[0]))
142142

143-
best_f_expected = self.blockX_blockY[0].Y().squeeze().max()
144-
self.assertEqual(best_f, best_f_expected)
143+
best_f_expected = self.blockX_blockY[0].Y().max(dim=0).values
144+
self.assertAllClose(best_f, best_f_expected)
145145
with self.assertRaisesRegex(UnsupportedError, "require an objective"):
146146
get_best_f_mc(training_data=self.blockX_multiY)
147147
obj = LinearMCObjective(weights=torch.rand(2))
148148
best_f = get_best_f_mc(training_data=self.blockX_multiY, objective=obj)
149149

150150
multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
151151
best_f_expected = (multi_Y @ obj.weights).amax(dim=-1, keepdim=True)
152-
self.assertEqual(best_f, best_f_expected)
152+
self.assertAllClose(best_f, best_f_expected)
153153
post_tf = ScalarizedPosteriorTransform(weights=torch.ones(2))
154154
best_f = get_best_f_mc(
155155
training_data=self.blockX_multiY, posterior_transform=post_tf
156156
)
157157
best_f_expected = (multi_Y.sum(dim=-1)).amax(dim=-1, keepdim=True)
158-
self.assertEqual(best_f, best_f_expected)
158+
self.assertAllClose(best_f, best_f_expected)
159159

160160
@mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
161161
def test_optimize_objective(self, mock_optimize_acqf):

0 commit comments

Comments
 (0)