|
18 | 18 | from contextlib import redirect_stdout
|
19 | 19 | from io import StringIO
|
20 | 20 | from unittest import mock
|
21 |
| -from unittest.mock import call |
| 21 | +from unittest.mock import ANY, call |
22 | 22 |
|
23 | 23 | import numpy as np
|
24 | 24 | import pytest
|
@@ -527,57 +527,47 @@ def test_step(self, batch, batch_idx):
|
527 | 527 | trainer = Trainer(
|
528 | 528 | default_root_dir=tmpdir,
|
529 | 529 | logger=TensorBoardLogger(tmpdir),
|
530 |
| - limit_train_batches=2, |
| 530 | + limit_train_batches=1, |
531 | 531 | limit_val_batches=2,
|
532 | 532 | limit_test_batches=2,
|
| 533 | + log_every_n_steps=1, |
533 | 534 | max_epochs=2,
|
534 | 535 | )
|
535 | 536 |
|
536 | 537 | # Train the model ⚡
|
537 | 538 | trainer.fit(model)
|
538 | 539 |
|
539 |
| - # hp_metric + 2 steps + epoch + 2 steps + epoch |
540 |
| - expected_num_calls = 1 + 2 + 1 + 2 + 1 |
541 |
| - |
542 | 540 | assert set(trainer.callback_metrics) == {
|
543 | 541 | "train_loss",
|
544 | 542 | "valid_loss_0_epoch",
|
545 | 543 | "valid_loss_0",
|
546 | 544 | "valid_loss_1",
|
547 | 545 | }
|
548 |
| - assert len(mock_log_metrics.mock_calls) == expected_num_calls |
549 |
| - assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0) |
| 546 | + assert mock_log_metrics.mock_calls == [ |
| 547 | + call({"hp_metric": -1}, 0), |
| 548 | + call(metrics={"train_loss": ANY, "epoch": 0}, step=0), |
| 549 | + call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=0), |
| 550 | + call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=1), |
| 551 | + call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 0}, step=0), |
| 552 | + call(metrics={"train_loss": ANY, "epoch": 1}, step=1), |
| 553 | + call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=2), |
| 554 | + call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=3), |
| 555 | + call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 1}, step=1), |
| 556 | + ] |
550 | 557 |
|
551 | 558 | def get_metrics_at_idx(idx):
|
552 | 559 | mock_call = mock_log_metrics.mock_calls[idx]
|
553 | 560 | return mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]
|
554 | 561 |
|
555 |
| - expected = {"valid_loss_0_step", "valid_loss_2"} |
556 |
| - assert set(get_metrics_at_idx(1)) == expected |
557 |
| - assert set(get_metrics_at_idx(2)) == expected |
558 |
| - |
559 |
| - assert get_metrics_at_idx(1)["valid_loss_0_step"] == model.val_losses[2] |
560 |
| - assert get_metrics_at_idx(2)["valid_loss_0_step"] == model.val_losses[3] |
561 |
| - |
562 |
| - assert set(get_metrics_at_idx(3)) == {"valid_loss_0_epoch", "valid_loss_1", "epoch"} |
563 |
| - |
564 |
| - assert get_metrics_at_idx(3)["valid_loss_1"] == torch.stack(model.val_losses[2:4]).mean() |
565 |
| - |
566 |
| - expected = {"valid_loss_0_step", "valid_loss_2"} |
567 |
| - assert set(get_metrics_at_idx(4)) == expected |
568 |
| - assert set(get_metrics_at_idx(5)) == expected |
569 |
| - |
570 |
| - assert get_metrics_at_idx(4)["valid_loss_0_step"] == model.val_losses[4] |
571 |
| - assert get_metrics_at_idx(5)["valid_loss_0_step"] == model.val_losses[5] |
572 |
| - |
573 |
| - assert set(get_metrics_at_idx(6)) == {"valid_loss_0_epoch", "valid_loss_1", "epoch"} |
574 |
| - |
575 |
| - assert get_metrics_at_idx(6)["valid_loss_1"] == torch.stack(model.val_losses[4:]).mean() |
| 562 | + assert get_metrics_at_idx(2)["valid_loss_0_step"] == model.val_losses[2] |
| 563 | + assert get_metrics_at_idx(3)["valid_loss_0_step"] == model.val_losses[3] |
| 564 | + assert get_metrics_at_idx(4)["valid_loss_1"] == torch.stack(model.val_losses[2:4]).mean() |
| 565 | + assert get_metrics_at_idx(6)["valid_loss_0_step"] == model.val_losses[4] |
| 566 | + assert get_metrics_at_idx(7)["valid_loss_0_step"] == model.val_losses[5] |
| 567 | + assert get_metrics_at_idx(8)["valid_loss_1"] == torch.stack(model.val_losses[4:]).mean() |
576 | 568 |
|
577 | 569 | results = trainer.test(model)
|
578 |
| - assert set(trainer.callback_metrics) == { |
579 |
| - "test_loss", |
580 |
| - } |
| 570 | + assert set(trainer.callback_metrics) == {"test_loss"} |
581 | 571 | assert set(results[0]) == {"test_loss"}
|
582 | 572 |
|
583 | 573 |
|
|
0 commit comments