@@ -88,6 +88,7 @@ def __init__(
         self._t0s: dict[RunningStage, float] = {}
         self._lengths: dict[RunningStage, int] = {}
         self._samples: dict[RunningStage, int] = {}
+        self._batches: dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -107,10 +108,14 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
     def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
-        self._throughputs[stage].reset()
-        self._lengths[stage] = 0
+
+        if stage not in self._samples:
+            self._throughputs[stage].reset()
+            self._lengths[stage] = 0
+            self._samples[stage] = 0
+            self._batches[stage] = 0
+
         self._t0s[stage] = time.perf_counter()
-        self._samples[stage] = 0
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
     def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
@@ -136,10 +141,11 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             flops_per_batch = None
 
         self._samples[stage] += self.batch_size_fn(batch)
+        self._batches[stage] += 1
 
         throughput.update(
             time=elapsed,
-            batches=iter_num,
+            batches=self._batches[stage],
             # this assumes that all iterations used the same batch size
             samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
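
For context, a minimal standalone sketch (not the Lightning `ThroughputMonitor` API) of the bookkeeping this change appears to introduce: a cumulative per-stage batch counter that stays in step with the cumulative sample counter, instead of a per-run `iter_num` that may restart when the same stage (e.g. validation) runs again. The names `_update`, `_batches`, and `_samples` below are illustrative only, and the restarting `iter_num` is an assumption inferred from the diff, not confirmed by it.

```python
# Illustrative sketch only; not the Lightning implementation.
from collections import defaultdict

_batches: dict[str, int] = defaultdict(int)  # cumulative batches per stage
_samples: dict[str, int] = defaultdict(int)  # cumulative samples per stage


def _update(stage: str, iter_num: int, batch_size: int) -> None:
    _samples[stage] += batch_size
    _batches[stage] += 1
    # The cumulative counters stay consistent with each other even if
    # `iter_num` restarts between runs of the same stage.
    print(f"{stage}: iter_num={iter_num} batches={_batches[stage]} samples={_samples[stage]}")


# Two validation runs of two batches each: `iter_num` repeats 1, 2 on the
# second run, while the cumulative counters keep growing.
for _ in range(2):
    for i in range(1, 3):
        _update("validate", iter_num=i, batch_size=4)
```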