Skip to content

Commit 3d4f8f6

Browse files
BanzaiTokyovfdev-5
andauthored
adds available_device to test_hsic.py #3335 (#3361)
Co-authored-by: vfdev <[email protected]>
1 parent 9287b35 commit 3d4f8f6

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/ignite/metrics/test_hsic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,11 @@ def test_case(request) -> Tuple[Tensor, Tensor, int]:
8787
@pytest.mark.parametrize("n_times", range(3))
8888
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
8989
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
90-
def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int]):
90+
def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int], available_device):
9191
x, y, batch_size = test_case
9292

93-
hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y)
93+
hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=available_device)
94+
assert hsic._device == torch.device(available_device)
9495

9596
hsic.reset()
9697

@@ -109,8 +110,9 @@ def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tenso
109110
assert pytest.approx(expected_hsic, abs=2e-5) == hsic.compute()
110111

111112

112-
def test_accumulator_detached():
113-
hsic = HSIC()
113+
def test_accumulator_detached(available_device):
114+
hsic = HSIC(device=available_device)
115+
assert hsic._device == torch.device(available_device)
114116

115117
x = torch.rand(10, 10, dtype=torch.float)
116118
y = torch.rand(10, 10, dtype=torch.float)

0 commit comments

Comments
 (0)