1414
1515
1616@pytest .mark .skipif (sys .platform .startswith ("win" ), reason = "Skip on Windows" )
17- def test_nondistributed_average ():
17+ def test_nondistributed_average (available_device ):
1818 artificial_time = 1 # seconds
1919 num_tokens = 100
2020 average_upper_bound = num_tokens / artificial_time
2121 average_lower_bound = average_upper_bound * 0.9
22- freq_metric = Frequency ()
22+ freq_metric = Frequency (device = available_device )
23+ assert freq_metric ._device == torch .device (available_device )
2324 freq_metric .reset ()
2425 time .sleep (artificial_time )
2526 freq_metric .update (num_tokens )
2627 average = freq_metric .compute ()
2728 assert average_lower_bound < average < average_upper_bound
2829
2930
30- def _test_frequency_with_engine (workers = None , lower_bound_factor = 0.8 , upper_bound_factor = 1.1 , every = 1 ):
31+ def _test_frequency_with_engine (workers = None , lower_bound_factor = 0.8 , upper_bound_factor = 1.1 , every = 1 , device = "cpu" ):
3132 if workers is None :
3233 workers = idist .get_world_size ()
3334
@@ -42,7 +43,9 @@ def update_fn(engine, batch):
4243 return {"ntokens" : len (batch )}
4344
4445 engine = Engine (update_fn )
45- wps_metric = Frequency (output_transform = lambda x : x ["ntokens" ])
46+ wps_metric = Frequency (output_transform = lambda x : x ["ntokens" ], device = device )
47+ assert wps_metric ._device == torch .device (device )
48+
4649 event = Events .ITERATION_COMPLETED (every = every )
4750 wps_metric .attach (engine , "wps" , event_name = event )
4851
@@ -63,8 +66,8 @@ def assert_wps(e):
6366
6467
6568@pytest .mark .skipif (sys .platform .startswith ("win" ), reason = "Skip on Windows" )
66- def test_frequency_with_engine ():
67- _test_frequency_with_engine (workers = 1 )
69+ def test_frequency_with_engine (available_device ):
70+ _test_frequency_with_engine (workers = 1 , device = available_device )
6871
6972
7073@pytest .mark .distributed
@@ -73,9 +76,9 @@ def test_frequency_with_engine_distributed(distributed_context_single_node_gloo)
7376 _test_frequency_with_engine (workers = idist .get_world_size ())
7477
7578
76- def test_frequency_with_engine_with_every ():
77- _test_frequency_with_engine (workers = 1 , every = 1 )
78- _test_frequency_with_engine (workers = 1 , every = 10 )
79+ def test_frequency_with_engine_with_every (available_device ):
80+ _test_frequency_with_engine (workers = 1 , every = 1 , device = available_device )
81+ _test_frequency_with_engine (workers = 1 , every = 10 , device = available_device )
7982
8083
8184@pytest .mark .distributed
0 commit comments