Skip to content

Commit 79a5d28

Browse files
authored
adds available_device to test_maximum_mean_discrepancy #3335 (#3365)
1 parent 4c24f61 commit 79a5d28

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/ignite/metrics/test_maximum_mean_discrepancy.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ def test_case(request):
7070

7171

7272
@pytest.mark.parametrize("n_times", range(5))
73-
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]):
73+
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int], available_device):
7474
x, y, var, batch_size = test_case
7575

76-
mmd = MaximumMeanDiscrepancy(var=var)
76+
mmd = MaximumMeanDiscrepancy(var=var, device=available_device)
77+
assert mmd._device == torch.device(available_device)
7778
mmd.reset()
7879

7980
if batch_size > 1:
@@ -97,8 +98,9 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]):
9798
assert pytest.approx(np_res, abs=1e-4) == res
9899

99100

100-
def test_accumulator_detached():
101-
mmd = MaximumMeanDiscrepancy()
101+
def test_accumulator_detached(available_device):
102+
mmd = MaximumMeanDiscrepancy(device=available_device)
103+
assert mmd._device == torch.device(available_device)
102104

103105
x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
104106
y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)

0 commit comments

Comments
 (0)