@@ -70,10 +70,11 @@ def test_case(request):
7070
7171
7272@pytest .mark .parametrize ("n_times" , range (5 ))
73- def test_compute (n_times , test_case : Tuple [Tensor , Tensor , float , int ]):
73+ def test_compute (n_times , test_case : Tuple [Tensor , Tensor , float , int ], available_device ):
7474 x , y , var , batch_size = test_case
7575
76- mmd = MaximumMeanDiscrepancy (var = var )
76+ mmd = MaximumMeanDiscrepancy (var = var , device = available_device )
77+ assert mmd ._device == torch .device (available_device )
7778 mmd .reset ()
7879
7980 if batch_size > 1 :
@@ -97,8 +98,9 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]):
9798 assert pytest .approx (np_res , abs = 1e-4 ) == res
9899
99100
100- def test_accumulator_detached ():
101- mmd = MaximumMeanDiscrepancy ()
101+ def test_accumulator_detached (available_device ):
102+ mmd = MaximumMeanDiscrepancy (device = available_device )
103+ assert mmd ._device == torch .device (available_device )
102104
103105 x = torch .tensor ([[2.0 , 3.0 ], [- 2.0 , 1.0 ]], dtype = torch .float )
104106 y = torch .tensor ([[- 2.0 , 1.0 ], [2.0 , 3.0 ]], dtype = torch .float )
0 commit comments