 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import defaultdict
 from unittest import mock
 from unittest.mock import DEFAULT, Mock

@@ -201,3 +202,68 @@ def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):

     trainer.fit(model)
     assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
+
+
+@RunIf(rich=True)
+def test_rich_progress_bar_correct_value_epoch_end(tmpdir):
+    """Rich counterpart to test_tqdm_progress_bar::test_tqdm_progress_bar_correct_value_epoch_end."""
+
+    class MockedProgressBar(RichProgressBar):
+        calls = defaultdict(list)
+
+        def get_metrics(self, trainer, pl_module):
+            items = super().get_metrics(trainer, model)
+            del items["v_num"]
+            del items["loss"]
+            # this is equivalent to mocking `set_postfix` as this method gets called every time
+            self.calls[trainer.state.fn].append(
+                (trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
+            )
+            return items
+
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            self.log("a", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
+            return super().validation_step(batch, batch_idx)
+
+        def test_step(self, batch, batch_idx):
+            self.log("c", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
+            return super().test_step(batch, batch_idx)
+
+    model = MyModel()
+    pbar = MockedProgressBar()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        max_epochs=2,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        log_every_n_steps=1,
+        callbacks=pbar,
+    )
+
+    trainer.fit(model)
+    assert pbar.calls["fit"] == [
+        ("sanity_check", 0, 0, {"b": 0}),
+        ("train", 0, 0, {}),
+        ("train", 0, 1, {}),
+        ("validate", 0, 1, {"b": 1}),  # validation end
+        # epoch end over, `on_epoch=True` metrics are computed
+        ("train", 0, 2, {"a": 1, "b": 1}),  # training epoch end
+        ("train", 1, 2, {"a": 1, "b": 1}),
+        ("train", 1, 3, {"a": 1, "b": 1}),
+        ("validate", 1, 3, {"a": 1, "b": 3}),  # validation end
+        ("train", 1, 4, {"a": 3, "b": 3}),  # training epoch end
+    ]
+
+    trainer.validate(model, verbose=False)
+    assert pbar.calls["validate"] == []
+
+    trainer.test(model, verbose=False)
+    assert pbar.calls["test"] == []