@@ -158,7 +158,9 @@ def test_q_max_value_entropy(self):
158158
159159 # Test with multi-output model w/ transform.
160160 mm = MESMockModel (num_outputs = 2 )
161- pt = ScalarizedPosteriorTransform (weights = torch .ones (2 , dtype = dtype ))
161+ pt = ScalarizedPosteriorTransform (
162+ weights = torch .ones (2 , device = self .device , dtype = dtype )
163+ )
162164 for gumbel in (True , False ):
163165 qMVE = qMaxValueEntropy (
164166 mm ,
@@ -225,7 +227,9 @@ def test_q_lower_bound_max_value_entropy(self):
225227
226228 # Test with multi-output model w/ transform.
227229 mm = MESMockModel (num_outputs = 2 )
228- pt = ScalarizedPosteriorTransform (weights = torch .ones (2 , dtype = dtype ))
230+ pt = ScalarizedPosteriorTransform (
231+ weights = torch .ones (2 , device = self .device , dtype = dtype )
232+ )
229233 qGIBBON = qLowerBoundMaxValueEntropy (
230234 mm ,
231235 candidate_set ,
@@ -270,7 +274,9 @@ def test_q_multi_fidelity_max_value_entropy(self):
270274
271275 # Test with multi-output model w/ transform.
272276 mm = MESMockModel (num_outputs = 2 )
273- pt = ScalarizedPosteriorTransform (weights = torch .ones (2 , dtype = dtype ))
277+ pt = ScalarizedPosteriorTransform (
278+ weights = torch .ones (2 , device = self .device , dtype = dtype )
279+ )
274280 qMF_MVE = qMultiFidelityMaxValueEntropy (
275281 model = mm ,
276282 candidate_set = candidate_set ,
@@ -284,13 +290,15 @@ def test_sample_max_value_Gumbel(self):
284290 for dtype in (torch .float , torch .double ):
285291 torch .manual_seed (7 )
286292 mm = MESMockModel ()
287- candidate_set = torch .rand (3 , 10 , 2 , dtype = dtype )
293+ candidate_set = torch .rand (3 , 10 , 2 , device = self . device , dtype = dtype )
288294 samples = _sample_max_value_Gumbel (mm , candidate_set , 5 )
289295 self .assertEqual (samples .shape , torch .Size ([5 , 3 ]))
290296
291297 # Test with multi-output model w/ transform.
292298 mm = MESMockModel (num_outputs = 2 )
293- pt = ScalarizedPosteriorTransform (weights = torch .ones (2 , dtype = dtype ))
299+ pt = ScalarizedPosteriorTransform (
300+ weights = torch .ones (2 , device = self .device , dtype = dtype )
301+ )
294302 samples = _sample_max_value_Gumbel (
295303 mm , candidate_set , 5 , posterior_transform = pt
296304 )
@@ -300,13 +308,15 @@ def test_sample_max_value_Thompson(self):
300308 for dtype in (torch .float , torch .double ):
301309 torch .manual_seed (7 )
302310 mm = MESMockModel ()
303- candidate_set = torch .rand (3 , 10 , 2 , dtype = dtype )
311+ candidate_set = torch .rand (3 , 10 , 2 , device = self . device , dtype = dtype )
304312 samples = _sample_max_value_Thompson (mm , candidate_set , 5 )
305313 self .assertEqual (samples .shape , torch .Size ([5 , 3 ]))
306314
307315 # Test with multi-output model w/ transform.
308316 mm = MESMockModel (num_outputs = 2 )
309- pt = ScalarizedPosteriorTransform (weights = torch .ones (2 , dtype = dtype ))
317+ pt = ScalarizedPosteriorTransform (
318+ weights = torch .ones (2 , device = self .device , dtype = dtype )
319+ )
310320 samples = _sample_max_value_Thompson (
311321 mm , candidate_set , 5 , posterior_transform = pt
312322 )
0 commit comments