@@ -87,10 +87,11 @@ def test_case(request) -> Tuple[Tensor, Tensor, int]:
8787@pytest .mark .parametrize ("n_times" , range (3 ))
8888@pytest .mark .parametrize ("sigma_x" , [- 1.0 , 1.0 ])
8989@pytest .mark .parametrize ("sigma_y" , [- 1.0 , 1.0 ])
90- def test_compute (n_times , sigma_x : float , sigma_y : float , test_case : Tuple [Tensor , Tensor , int ]):
90+ def test_compute (n_times , sigma_x : float , sigma_y : float , test_case : Tuple [Tensor , Tensor , int ], available_device ):
9191 x , y , batch_size = test_case
9292
93- hsic = HSIC (sigma_x = sigma_x , sigma_y = sigma_y )
93+ hsic = HSIC (sigma_x = sigma_x , sigma_y = sigma_y , device = available_device )
94+ assert hsic ._device == torch .device (available_device )
9495
9596 hsic .reset ()
9697
@@ -109,8 +110,9 @@ def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tenso
109110 assert pytest .approx (expected_hsic , abs = 2e-5 ) == hsic .compute ()
110111
111112
112- def test_accumulator_detached ():
113- hsic = HSIC ()
113+ def test_accumulator_detached (available_device ):
114+ hsic = HSIC (device = available_device )
115+ assert hsic ._device == torch .device (available_device )
114116
115117 x = torch .rand (10 , 10 , dtype = torch .float )
116118 y = torch .rand (10 , 10 , dtype = torch .float )
0 commit comments