@@ -60,10 +60,11 @@ def test_case(request):
 
 
 @pytest.mark.parametrize("n_times", range(5))
-def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
+def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int], available_device):
     y_pred, y, batch_size = test_case
 
-    js_div = JSDivergence()
+    js_div = JSDivergence(device=available_device)
+    assert js_div._device == torch.device(available_device)
 
     js_div.reset()
     if batch_size > 1:
@@ -85,8 +86,9 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
     assert pytest.approx(np_res, rel=1e-4) == res
 
 
-def test_accumulator_detached():
-    js_div = JSDivergence()
+def test_accumulator_detached(available_device):
+    js_div = JSDivergence(device=available_device)
+    assert js_div._device == torch.device(available_device)
 
     y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
     y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)
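
Note: these tests now take an available_device argument, which pytest resolves as a fixture of that name, typically defined in conftest.py. The actual fixture is not part of this diff; the sketch below is a minimal, assumed version that only yields each device string usable on the current machine, so it should not be read as the project's real implementation.

# Hypothetical conftest.py sketch; the project's actual fixture may differ,
# for example by also covering MPS or XLA devices.
import pytest
import torch


@pytest.fixture(params=["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
def available_device(request):
    # Runs every test that requests this fixture once per usable device,
    # so JSDivergence(device=...) is exercised on CPU and, when present, CUDA.
    return request.param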