
Commit 18a3b26

dme65 authored and facebook-github-bot committed

Fix broken cuda tests (#1115)

Summary:
Pull Request resolved: #1115

See title

Reviewed By: saitcakmak

Differential Revision: D34761452

fbshipit-source-id: b29f6d86d3aaa8b8b16e6c660d8d8d7dac65758d
1 parent 21b4616 commit 18a3b26


4 files changed: +22 −10 lines changed


botorch/models/multitask.py

Lines changed: 1 addition & 1 deletion
@@ -473,7 +473,7 @@ def __init__(
         if task_covar_prior is None:
             task_covar_prior = LKJCovariancePrior(
                 n=num_tasks,
-                eta=kwargs.get("eta", 1.5),
+                eta=torch.tensor(kwargs.get("eta", 1.5)).to(train_X),
                 sd_prior=kwargs.get(
                     "sd_prior",
                     SmoothedBoxPrior(math.exp(-6), math.exp(1.25), 0.05),
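The library-side change above is the heart of the CUDA fix: eta was previously a bare Python float, which carries no device or dtype information, so the LKJCovariancePrior could end up allocating CPU tensors even when train_X lives on the GPU. Below is a minimal sketch of the .to(train_X) idiom; the tensors are illustrative stand-ins, not the commit's code.

import torch

# Assumed, illustrative inputs: use the GPU when available so a mismatch would be visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_X = torch.rand(10, 3, device=device, dtype=torch.double)

# Before the fix: a bare Python float with no device/dtype attached.
eta_float = 1.5

# After the fix: wrap the value and move it onto train_X's device and dtype in one call.
eta = torch.tensor(eta_float).to(train_X)
assert eta.device == train_X.device and eta.dtype == train_X.dtype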

test/acquisition/multi_objective/test_utils.py

Lines changed: 3 additions & 1 deletion
@@ -129,7 +129,9 @@ def test_prune_inferior_points_multi_objective(self):
             )
             if self.device.type == "cuda":
                 # sorting has different order on cuda
-                self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
+                self.assertTrue(
+                    torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]])
+                )
             else:
                 self.assertTrue(torch.equal(X_pruned, X[:2]))
             # test that zero-probability is in fact pruned
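The test-side change accepts either ordering of the two surviving points, since the sort backing the pruning can break ties differently on CUDA than on CPU. A small order-insensitive comparison in the same spirit, with made-up data rather than the test's:

import torch

# Illustrative stand-ins for the candidate set and the pruned result (not the test's data).
X = torch.rand(4, 2)
X_pruned = X[[2, 1]]  # pretend the CUDA code path returned the rows in this order

# Pass if the same two rows come back in either order, as the updated assertion does.
assert torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]])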

test/acquisition/test_max_value_entropy_search.py

Lines changed: 17 additions & 7 deletions
@@ -158,7 +158,9 @@ def test_q_max_value_entropy(self):

             # Test with multi-output model w/ transform.
             mm = MESMockModel(num_outputs=2)
-            pt = ScalarizedPosteriorTransform(weights=torch.ones(2, dtype=dtype))
+            pt = ScalarizedPosteriorTransform(
+                weights=torch.ones(2, device=self.device, dtype=dtype)
+            )
             for gumbel in (True, False):
                 qMVE = qMaxValueEntropy(
                     mm,
@@ -225,7 +227,9 @@ def test_q_lower_bound_max_value_entropy(self):

             # Test with multi-output model w/ transform.
             mm = MESMockModel(num_outputs=2)
-            pt = ScalarizedPosteriorTransform(weights=torch.ones(2, dtype=dtype))
+            pt = ScalarizedPosteriorTransform(
+                weights=torch.ones(2, device=self.device, dtype=dtype)
+            )
             qGIBBON = qLowerBoundMaxValueEntropy(
                 mm,
                 candidate_set,
@@ -270,7 +274,9 @@ def test_q_multi_fidelity_max_value_entropy(self):

             # Test with multi-output model w/ transform.
             mm = MESMockModel(num_outputs=2)
-            pt = ScalarizedPosteriorTransform(weights=torch.ones(2, dtype=dtype))
+            pt = ScalarizedPosteriorTransform(
+                weights=torch.ones(2, device=self.device, dtype=dtype)
+            )
             qMF_MVE = qMultiFidelityMaxValueEntropy(
                 model=mm,
                 candidate_set=candidate_set,
@@ -284,13 +290,15 @@ def test_sample_max_value_Gumbel(self):
         for dtype in (torch.float, torch.double):
             torch.manual_seed(7)
             mm = MESMockModel()
-            candidate_set = torch.rand(3, 10, 2, dtype=dtype)
+            candidate_set = torch.rand(3, 10, 2, device=self.device, dtype=dtype)
             samples = _sample_max_value_Gumbel(mm, candidate_set, 5)
             self.assertEqual(samples.shape, torch.Size([5, 3]))

             # Test with multi-output model w/ transform.
             mm = MESMockModel(num_outputs=2)
-            pt = ScalarizedPosteriorTransform(weights=torch.ones(2, dtype=dtype))
+            pt = ScalarizedPosteriorTransform(
+                weights=torch.ones(2, device=self.device, dtype=dtype)
+            )
             samples = _sample_max_value_Gumbel(
                 mm, candidate_set, 5, posterior_transform=pt
             )
@@ -300,13 +308,15 @@ def test_sample_max_value_Thompson(self):
         for dtype in (torch.float, torch.double):
             torch.manual_seed(7)
             mm = MESMockModel()
-            candidate_set = torch.rand(3, 10, 2, dtype=dtype)
+            candidate_set = torch.rand(3, 10, 2, device=self.device, dtype=dtype)
             samples = _sample_max_value_Thompson(mm, candidate_set, 5)
             self.assertEqual(samples.shape, torch.Size([5, 3]))

             # Test with multi-output model w/ transform.
             mm = MESMockModel(num_outputs=2)
-            pt = ScalarizedPosteriorTransform(weights=torch.ones(2, dtype=dtype))
+            pt = ScalarizedPosteriorTransform(
+                weights=torch.ones(2, device=self.device, dtype=dtype)
+            )
             samples = _sample_max_value_Thompson(
                 mm, candidate_set, 5, posterior_transform=pt
             )
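Every hunk in this file applies the same pattern: tensors created inside the tests now pass device=self.device alongside dtype, so candidate sets and posterior-transform weights sit on the same device as the model when the suite runs on CUDA. A standalone sketch of that pattern, with DeviceAwareTest as an illustrative stand-in for the test base class rather than anything in BoTorch:

import torch

class DeviceAwareTest:
    # Stand-in for the test base class's device attribute; on CUDA runs this
    # points at the GPU, which is what the commit relies on.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def make_inputs(self, dtype=torch.double):
        # Build test tensors on self.device so they match model tensors on GPU runs.
        candidate_set = torch.rand(3, 10, 2, device=self.device, dtype=dtype)
        weights = torch.ones(2, device=self.device, dtype=dtype)
        return candidate_set, weights

candidate_set, weights = DeviceAwareTest().make_inputs()
print(candidate_set.device, weights.device)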

test/models/test_multitask.py

Lines changed: 1 addition & 1 deletion
@@ -706,7 +706,7 @@ def test_KroneckerMultiTaskGP_custom(self):
             )
             task_covar_prior = LKJCovariancePrior(
                 n=2,
-                eta=0.5,
+                eta=torch.tensor(0.5, **tkwargs),
                 sd_prior=SmoothedBoxPrior(math.exp(-3), math.exp(2), 0.1),
             )
             model_kwargs = {
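The test mirrors the library change by building eta with the suite's tkwargs dictionary, which bundles the target device and dtype. A sketch of that idiom; the contents of tkwargs here are assumed rather than taken from the test file:

import torch

# Typical tkwargs bundle used to keep device and dtype consistent across a test.
tkwargs = {
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.double,
}

# eta as a scalar tensor carrying the test's device and dtype.
eta = torch.tensor(0.5, **tkwargs)
print(eta)  # tensor(0.5000, dtype=torch.float64), on the GPU when one is available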
