Skip to content

Commit 4bb80b7

Browse files
BanzaiTokyovfdev-5
andauthored
adds available_device to test_js_divergence.py #3335 (#3362)
Co-authored-by: vfdev <[email protected]>
1 parent 3d4f8f6 commit 4bb80b7

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/ignite/metrics/test_js_divergence.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,11 @@ def test_case(request):
6060

6161

6262
@pytest.mark.parametrize("n_times", range(5))
63-
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
63+
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int], available_device):
6464
y_pred, y, batch_size = test_case
6565

66-
js_div = JSDivergence()
66+
js_div = JSDivergence(device=available_device)
67+
assert js_div._device == torch.device(available_device)
6768

6869
js_div.reset()
6970
if batch_size > 1:
@@ -85,8 +86,9 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
8586
assert pytest.approx(np_res, rel=1e-4) == res
8687

8788

88-
def test_accumulator_detached():
89-
js_div = JSDivergence()
89+
def test_accumulator_detached(available_device):
90+
js_div = JSDivergence(device=available_device)
91+
assert js_div._device == torch.device(available_device)
9092

9193
y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
9294
y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)

0 commit comments

Comments
 (0)