@@ -87,6 +87,7 @@ def __init__(
         self._throughputs: Dict[RunningStage, Throughput] = {}
         self._t0s: Dict[RunningStage, float] = {}
         self._lengths: Dict[RunningStage, int] = {}
+        self._samples: Dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -109,6 +110,7 @@ def _start(self, trainer: "Trainer") -> None:
         self._throughputs[stage].reset()
         self._lengths[stage] = 0
         self._t0s[stage] = time.perf_counter()
+        self._samples[stage] = 0
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
     def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
@@ -133,12 +135,13 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             )
             flops_per_batch = None
 
-        batch_size = self.batch_size_fn(batch)
+        self._samples[stage] += self.batch_size_fn(batch)
+
         throughput.update(
             time=elapsed,
             batches=iter_num,
             # this assumes that all iterations used the same batch size
-            samples=iter_num * batch_size,
+            samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
             flops=flops_per_batch,
        )
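
The pre-patch bookkeeping extrapolated the most recent batch size to every iteration seen so far (`samples = iter_num * batch_size`), which miscounts as soon as batch sizes vary, for example when a loader without `drop_last` yields a short final batch. The patch instead keeps a per-stage running total that accumulates `batch_size_fn(batch)` on every update and is reset in `_start`. A minimal standalone sketch of the two bookkeeping strategies (the helper names and batch sizes below are illustrative, not part of the patch):

```python
from typing import List


def samples_old(batch_sizes: List[int]) -> int:
    """Pre-patch: extrapolate the latest batch size to all iterations."""
    iter_num = len(batch_sizes)
    return iter_num * batch_sizes[-1]


def samples_new(batch_sizes: List[int]) -> int:
    """Post-patch: accumulate each batch's actual size, mirroring
    `self._samples[stage] += self.batch_size_fn(batch)`."""
    total = 0
    for size in batch_sizes:
        total += size
    return total


# A loader without `drop_last` often ends an epoch with a short batch:
sizes = [32, 32, 32, 7]
print(samples_old(sizes))  # 28  (4 iterations * last size 7 -- wrong)
print(samples_new(sizes))  # 103 (32 + 32 + 32 + 7 -- correct)
```

Because `throughput.update` already receives cumulative totals (`batches=iter_num`, and previously `samples=iter_num * batch_size`), the running counter slots into the existing call unchanged, and the `self._samples[stage] = 0` reset in `_start` scopes the total to each stage run.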