Skip to content

Commit 9287b35

Browse files
BanzaiTokyovfdev-5
andauthored
adds available_device to test_frequency.py #3335 (#3359)
Co-authored-by: vfdev <[email protected]>
1 parent df5e970 commit 9287b35

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tests/ignite/metrics/test_frequency.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,21 @@
1414

1515

1616
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Skip on Windows")
17-
def test_nondistributed_average():
17+
def test_nondistributed_average(available_device):
1818
artificial_time = 1 # seconds
1919
num_tokens = 100
2020
average_upper_bound = num_tokens / artificial_time
2121
average_lower_bound = average_upper_bound * 0.9
22-
freq_metric = Frequency()
22+
freq_metric = Frequency(device=available_device)
23+
assert freq_metric._device == torch.device(available_device)
2324
freq_metric.reset()
2425
time.sleep(artificial_time)
2526
freq_metric.update(num_tokens)
2627
average = freq_metric.compute()
2728
assert average_lower_bound < average < average_upper_bound
2829

2930

30-
def _test_frequency_with_engine(workers=None, lower_bound_factor=0.8, upper_bound_factor=1.1, every=1):
31+
def _test_frequency_with_engine(workers=None, lower_bound_factor=0.8, upper_bound_factor=1.1, every=1, device="cpu"):
3132
if workers is None:
3233
workers = idist.get_world_size()
3334

@@ -42,7 +43,9 @@ def update_fn(engine, batch):
4243
return {"ntokens": len(batch)}
4344

4445
engine = Engine(update_fn)
45-
wps_metric = Frequency(output_transform=lambda x: x["ntokens"])
46+
wps_metric = Frequency(output_transform=lambda x: x["ntokens"], device=device)
47+
assert wps_metric._device == torch.device(device)
48+
4649
event = Events.ITERATION_COMPLETED(every=every)
4750
wps_metric.attach(engine, "wps", event_name=event)
4851

@@ -63,8 +66,8 @@ def assert_wps(e):
6366

6467

6568
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Skip on Windows")
66-
def test_frequency_with_engine():
67-
_test_frequency_with_engine(workers=1)
69+
def test_frequency_with_engine(available_device):
70+
_test_frequency_with_engine(workers=1, device=available_device)
6871

6972

7073
@pytest.mark.distributed
@@ -73,9 +76,9 @@ def test_frequency_with_engine_distributed(distributed_context_single_node_gloo)
7376
_test_frequency_with_engine(workers=idist.get_world_size())
7477

7578

76-
def test_frequency_with_engine_with_every():
77-
_test_frequency_with_engine(workers=1, every=1)
78-
_test_frequency_with_engine(workers=1, every=10)
79+
def test_frequency_with_engine_with_every(available_device):
80+
_test_frequency_with_engine(workers=1, every=1, device=available_device)
81+
_test_frequency_with_engine(workers=1, every=10, device=available_device)
7982

8083

8184
@pytest.mark.distributed

0 commit comments

Comments
 (0)