 # pyre-strict
 
 import unittest
-from unittest.mock import ANY, call, MagicMock
+from unittest.mock import ANY, call, MagicMock, patch
 
 import torch
 from pyre_extensions import none_throws
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     DummyPredictUnit,
+    DummyTrainUnit,
     generate_random_dataloader,
 )
 from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger
 from torchtnt.framework.predict import predict
 
-from torchtnt.framework.state import EntryPoint, PhaseState, State
-from torchtnt.framework.train import _train_impl
+from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
+from torchtnt.framework.train import _train_impl, train
+from torchtnt.framework.unit import TrainUnit
 from torchtnt.utils.loggers.logger import MetricLogger
 
 
@@ -121,21 +123,18 @@ def test_with_comparing_time(self) -> None:
                 evaluate_every_n_epochs=2,
             ),
         )
+        throughput_logger = ThroughputLogger(
+            logger=logger,
+            throughput_per_batch={"Batches": 1, "Queries": 8},
+            log_every_n_steps=1,
+        )
 
         # we want to be able to compare the logging value to the state, so we need to create state manually and
         # call _train_impl. This would have been similar to calling fit() and getting the state as a ret value
         _train_impl(
             state,
             DummyAutoUnit(module=torch.nn.Linear(2, 2)),
-            CallbackHandler(
-                [
-                    ThroughputLogger(
-                        logger=logger,
-                        throughput_per_batch={"Batches": 1, "Queries": 8},
-                        log_every_n_steps=1,
-                    )
-                ],
-            ),
+            CallbackHandler([throughput_logger]),
         )
 
         train_iteration_times = none_throws(
@@ -163,8 +162,8 @@ def test_with_comparing_time(self) -> None:
             eval_iteration_times[i] + eval_twfb_times[i] for i in range(2)
         ]
         self.assertEqual(
-            logger.log.call_count, 12
-        )  # 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2)
+            logger.log.call_count, 18
+        )  # steps: 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2). epochs: 4 train (2epochs x 2items), 2 eval (1x2)
         train_batches_step_logs = [
             call(
                 "Train: Batches per second (step granularity)",
@@ -197,11 +196,36 @@ def test_with_comparing_time(self) -> None:
             )
             for i in range(2)
         ]
+        # for epoch, we test the logged value separately
+        train_batches_epoch_logs = [
+            call("Train: Batches per second (epoch granularity)", ANY, i)
+            for i in range(1, 3)
+        ]
+        train_queries_epoch_logs = [
+            call("Train: Queries per second (epoch granularity)", ANY, i)
+            for i in range(1, 3)
+        ]
+        eval_epoch_logs = [
+            call(
+                "Eval: Queries per second (epoch granularity)",
+                ANY,
+                1,
+            ),
+            call(
+                "Eval: Batches per second (epoch granularity)",
+                ANY,
+                1,
+            ),
+        ]
+
         logger.log.assert_has_calls(
             train_batches_step_logs
             + train_queries_step_logs
             + eval_batches_step_logs
-            + eval_queries_step_logs,
+            + eval_queries_step_logs
+            + train_batches_epoch_logs
+            + train_queries_epoch_logs
+            + eval_epoch_logs,
             any_order=True,
         )
 
@@ -227,6 +251,79 @@ def test_with_predict(self) -> None:
                     1,
                 )
             ],
+            [
+                call(
+                    "Predict: Batches per second (epoch granularity)",
+                    ANY,
+                    1,
+                )
+            ],
+        )
+
+    def test_log_for_epoch(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        unit = DummyTrainUnit(input_dim=2)
+        throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Queries": 8})
+        state = State(entry_point=EntryPoint.TRAIN)
+
+        self.assertIsNone(throughput_logger._epoch_start_times.get(ActivePhase.TRAIN))
+        self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 0)
+        with patch.object(throughput_logger, "_maybe_log_for_step"):
+            throughput_logger.on_train_step_end(state, unit)
+        self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 1)
+
+        with patch("time.perf_counter", return_value=0.5):
+            throughput_logger.on_train_epoch_start(state, MagicMock(spec=TrainUnit))
+        self.assertEqual(throughput_logger._epoch_start_times[ActivePhase.TRAIN], 0.5)
+
+        throughput_logger._steps_in_epoch[ActivePhase.TRAIN] = (
+            2  # to assume there were two steps in the epoch
+        )
+        logger.log.reset_mock()
+        with patch("time.perf_counter", return_value=0.6):
+            throughput_logger._log_for_epoch(state, epoch_logging_for=15)
+
+        logger.log.assert_has_calls(
+            [
+                call(
+                    "Train: Batches per second (epoch granularity)",
+                    (1 * 2) / (0.6 - 0.5),
+                    15,
+                ),
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (8 * 2) / (0.6 - 0.5),
+                    15,
+                ),
+            ]
+        )
+
+    def test_epoch_logging_time(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        throughput_logger = ThroughputLogger(logger, {"Queries": 4})
+        with patch("time.perf_counter", side_effect=[0.1, 0.5, 0.8, 1.5]):
+            train(
+                DummyTrainUnit(input_dim=2),
+                generate_random_dataloader(num_samples=16, input_dim=2, batch_size=4),
+                max_epochs=2,
+                max_steps_per_epoch=2,
+                callbacks=[throughput_logger],
+            )
+
+        logger.log.assert_has_calls(
+            [
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (4 * 2) / (0.5 - 0.1),
+                    1,
+                ),
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (4 * 2) / (1.5 - 0.8),
+                    2,
+                ),
+            ],
+            any_order=True,
         )
 
     def test_input_validation(self) -> None:
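
Note on the expected values asserted above: epoch-granularity throughput is the number of items processed during the epoch divided by the wall-clock time elapsed since the epoch started, which is why the tests patch time.perf_counter and expect values such as (4 * 2) / (0.5 - 0.1). A minimal sketch of that arithmetic follows; the helper name expected_epoch_throughput and its parameters are illustrative only and not part of the PR or the ThroughputLogger API.

def expected_epoch_throughput(
    items_per_batch: float, steps_in_epoch: int, epoch_start: float, epoch_end: float
) -> float:
    # total items processed in the epoch divided by the elapsed wall-clock time
    return (items_per_batch * steps_in_epoch) / (epoch_end - epoch_start)

# e.g. 4 queries per batch, 2 steps, perf_counter readings 0.1 -> 0.5:
assert expected_epoch_throughput(4, 2, 0.1, 0.5) == 20.0  # matches the first expected call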