Commit 214cbec

Balandat authored and facebook-github-bot committed

Fix cuda tests (#1357)

Summary:
Pull Request resolved: #1357

These slipped through.

Reviewed By: saitcakmak

Differential Revision: D38896631

fbshipit-source-id: afe1e4e7d394113bd5ab6de235c78256b40caf24

1 parent 62d1088 commit 214cbec

File tree: 5 files changed (+10, -12 lines)

botorch/acquisition/multi_objective/multi_output_risk_measures.py

Lines changed: 2 additions & 2 deletions
@@ -557,7 +557,7 @@ def __init__(
                 apply feasibility-weighting to samples.
         """
         super().__init__(alpha=alpha, n_w=n_w)
-        self.chebyshev_weights = chebyshev_weights
+        self.chebyshev_weights = torch.as_tensor(chebyshev_weights)
         self.baseline_Y = baseline_Y
         self.register_buffer(
             "ref_point", torch.as_tensor(ref_point) if ref_point is not None else None
@@ -661,7 +661,7 @@ def chebyshev_obj(Y: Tensor, X: Optional[Tensor] = None) -> Tensor:
             Y = normalize(Y, bounds=Y_bounds)
             if ref_point is not None:
                 Y = Y - ref_point
-            product = torch.einsum("...m,m->...m", Y, self.chebyshev_weights)
+            product = torch.einsum("...m,m->...m", Y, self.chebyshev_weights.to(Y))
             return product.min(dim=-1).values

         self._chebyshev_objective = chebyshev_obj
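The functional change here is the .to(Y) call: Tensor.to(other) converts a tensor to the dtype and device of other, so weights stored on the CPU at construction time can be multiplied against CUDA samples at call time. A minimal standalone sketch (the tensors below are illustrative, not BoTorch's):

    import torch

    # Weights are stored at construction time, typically on the CPU with
    # the default dtype; the einsum input Y may later arrive on a GPU.
    weights = torch.as_tensor([0.5, 0.5])

    Y = torch.rand(4, 2, dtype=torch.float64)  # stand-in for posterior samples
    # Tensor.to(Y) matches both Y's dtype and Y's device, avoiding a
    # cross-device (or cross-dtype) einsum error when Y lives on a GPU.
    product = torch.einsum("...m,m->...m", Y, weights.to(Y))
    assert product.dtype == torch.float64 and product.shape == (4, 2)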

test/acquisition/multi_objective/test_multi_output_risk_measures.py

Lines changed: 1 addition & 1 deletion
@@ -614,4 +614,4 @@ def test_end_to_end(self):
         mars_vals = mars(samples)
         self.assertEqual(mars_vals.shape, torch.Size([5, 3]))
         self.assertEqual(mars_vals.dtype, dtype)
-        self.assertEqual(mars_vals.device, self.device)
+        self.assertEqual(mars_vals.device.type, self.device.type)
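The reason for comparing device.type rather than the full device object: an allocated tensor reports an indexed device such as cuda:0, while the test harness typically holds the index-less torch.device("cuda"), and the two compare unequal. A small sketch of the mismatch (runs without a GPU, since only device objects are constructed):

    import torch

    requested = torch.device("cuda")    # no index
    allocated = torch.device("cuda:0")  # what tensor.device reports on GPU 0
    assert allocated != requested       # full device comparison fails
    assert allocated.type == requested.type == "cuda"  # type comparison passes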

test/acquisition/test_proximal.py

Lines changed: 5 additions & 7 deletions
@@ -131,7 +131,7 @@ def test_proximal(self):
         self.assertTrue(
             torch.allclose(qei_prox, qei * test_prox_weight.flatten())
         )
-        self.assertTrue(qei_prox.shape == torch.Size([4]))
+        self.assertEqual(qei_prox.shape, torch.Size([4]))

         # test gradient
         test_X = torch.rand(
@@ -185,14 +185,12 @@ def test_proximal(self):

         multi_output_model = SingleTaskGP(train_X, train_Y).to(device=self.device)
         ptransform = ScalarizedPosteriorTransform(
-            weights=torch.ones(2, dtype=dtype)
+            weights=torch.ones(2, dtype=dtype, device=self.device)
         )
-        acq = ProximalAcquisitionFunction(
-            ExpectedImprovement(
-                multi_output_model, 0.0, posterior_transform=ptransform
-            ),
-            proximal_weights,
+        ei = ExpectedImprovement(
+            multi_output_model, 0.0, posterior_transform=ptransform
         )
+        acq = ProximalAcquisitionFunction(ei, proximal_weights)
         acq(test_X)

     def test_proximal_model_list(self):
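The substantive fix is device=self.device on the torch.ones call; splitting ExpectedImprovement out into ei is a readability refactor with no behavior change. A sketch of the underlying rule, assuming a model on some target device: tensors that feed a posterior transform should be created on that device rather than defaulting to the CPU.

    import torch

    # Assumption: the test class exposes the target device as self.device;
    # here we derive an equivalent device directly.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.double

    # Created directly on the target device, so downstream ops that combine
    # these weights with model outputs never mix CPU and CUDA tensors.
    weights = torch.ones(2, dtype=dtype, device=device)
    assert weights.device.type == device.type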

test/generation/test_gen.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def test_gen_candidates_scipy_warns_opt_no_res(self):
         with mock.patch(
             "botorch.generation.gen.minimize"
         ) as mock_minimize, warnings.catch_warnings(record=True) as ws:
-            mock_minimize.return_value = OptimizeResult(x=test_ics.numpy())
+            mock_minimize.return_value = OptimizeResult(x=test_ics.cpu().numpy())

             gen_candidates_scipy(
                 initial_conditions=test_ics,
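Tensor.numpy() shares memory with the tensor and therefore only works for CPU tensors; calling it on a CUDA tensor raises a TypeError telling you to call Tensor.cpu() first. The mocked OptimizeResult needs a NumPy array, hence the .cpu() hop. A minimal reproduction sketch:

    import torch

    x = torch.rand(3)
    if torch.cuda.is_available():
        x = x.cuda()
        # x.numpy() would raise TypeError here: NumPy arrays cannot alias
        # CUDA memory, so the tensor must be copied to the host first.
    arr = x.cpu().numpy()  # .cpu() is a no-op for tensors already on the CPU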

test/models/test_model_list_gp_regression.py

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None:
             FixedNoiseGP(X, Y, yvar, outcome_transform=Standardize(m=1))
         )

-        model.posterior(torch.zeros((1, 1)))
+        model.posterior(torch.zeros((1, 1), **tkwargs))

         fant = model.fantasize(
             X, sampler=IIDNormalSampler(n_fants, seed=0), noise=yvar
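tkwargs here follows the usual BoTorch test convention: a dict of tensor keyword arguments (device and dtype) unpacked into every constructor so test inputs match the model. A sketch of the pattern (the exact contents of tkwargs in this test are an assumption):

    import torch

    # Assumed shape of the convention; the real test builds this from its
    # parametrized device and dtype.
    tkwargs = {"device": torch.device("cpu"), "dtype": torch.double}

    X_query = torch.zeros((1, 1), **tkwargs)  # matches the model's device/dtype
    assert X_query.dtype == tkwargs["dtype"]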
