Skip to content

Commit de7faf9

Browse files
authored
Update evaluation logging test (#18896)
1 parent b8a96fe commit de7faf9

File tree

2 files changed

+24
-31
lines changed

2 files changed

+24
-31
lines changed

tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from contextlib import redirect_stdout
1919
from io import StringIO
2020
from unittest import mock
21-
from unittest.mock import call
21+
from unittest.mock import ANY, call
2222

2323
import numpy as np
2424
import pytest
@@ -527,57 +527,47 @@ def test_step(self, batch, batch_idx):
527527
trainer = Trainer(
528528
default_root_dir=tmpdir,
529529
logger=TensorBoardLogger(tmpdir),
530-
limit_train_batches=2,
530+
limit_train_batches=1,
531531
limit_val_batches=2,
532532
limit_test_batches=2,
533+
log_every_n_steps=1,
533534
max_epochs=2,
534535
)
535536

536537
# Train the model ⚡
537538
trainer.fit(model)
538539

539-
# hp_metric + 2 steps + epoch + 2 steps + epoch
540-
expected_num_calls = 1 + 2 + 1 + 2 + 1
541-
542540
assert set(trainer.callback_metrics) == {
543541
"train_loss",
544542
"valid_loss_0_epoch",
545543
"valid_loss_0",
546544
"valid_loss_1",
547545
}
548-
assert len(mock_log_metrics.mock_calls) == expected_num_calls
549-
assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)
546+
assert mock_log_metrics.mock_calls == [
547+
call({"hp_metric": -1}, 0),
548+
call(metrics={"train_loss": ANY, "epoch": 0}, step=0),
549+
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=0),
550+
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=1),
551+
call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 0}, step=0),
552+
call(metrics={"train_loss": ANY, "epoch": 1}, step=1),
553+
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=2),
554+
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=3),
555+
call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 1}, step=1),
556+
]
550557

551558
def get_metrics_at_idx(idx):
552559
mock_call = mock_log_metrics.mock_calls[idx]
553560
return mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]
554561

555-
expected = {"valid_loss_0_step", "valid_loss_2"}
556-
assert set(get_metrics_at_idx(1)) == expected
557-
assert set(get_metrics_at_idx(2)) == expected
558-
559-
assert get_metrics_at_idx(1)["valid_loss_0_step"] == model.val_losses[2]
560-
assert get_metrics_at_idx(2)["valid_loss_0_step"] == model.val_losses[3]
561-
562-
assert set(get_metrics_at_idx(3)) == {"valid_loss_0_epoch", "valid_loss_1", "epoch"}
563-
564-
assert get_metrics_at_idx(3)["valid_loss_1"] == torch.stack(model.val_losses[2:4]).mean()
565-
566-
expected = {"valid_loss_0_step", "valid_loss_2"}
567-
assert set(get_metrics_at_idx(4)) == expected
568-
assert set(get_metrics_at_idx(5)) == expected
569-
570-
assert get_metrics_at_idx(4)["valid_loss_0_step"] == model.val_losses[4]
571-
assert get_metrics_at_idx(5)["valid_loss_0_step"] == model.val_losses[5]
572-
573-
assert set(get_metrics_at_idx(6)) == {"valid_loss_0_epoch", "valid_loss_1", "epoch"}
574-
575-
assert get_metrics_at_idx(6)["valid_loss_1"] == torch.stack(model.val_losses[4:]).mean()
562+
assert get_metrics_at_idx(2)["valid_loss_0_step"] == model.val_losses[2]
563+
assert get_metrics_at_idx(3)["valid_loss_0_step"] == model.val_losses[3]
564+
assert get_metrics_at_idx(4)["valid_loss_1"] == torch.stack(model.val_losses[2:4]).mean()
565+
assert get_metrics_at_idx(6)["valid_loss_0_step"] == model.val_losses[4]
566+
assert get_metrics_at_idx(7)["valid_loss_0_step"] == model.val_losses[5]
567+
assert get_metrics_at_idx(8)["valid_loss_1"] == torch.stack(model.val_losses[4:]).mean()
576568

577569
results = trainer.test(model)
578-
assert set(trainer.callback_metrics) == {
579-
"test_loss",
580-
}
570+
assert set(trainer.callback_metrics) == {"test_loss"}
581571
assert set(results[0]) == {"test_loss"}
582572

583573

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,6 +1925,9 @@ def training_step(self, batch, batch_idx):
19251925
def test_trainer_config_strategy(monkeypatch, trainer_kwargs, strategy_cls, accelerator_cls, devices):
19261926
if trainer_kwargs.get("accelerator") == "cuda":
19271927
mock_cuda_count(monkeypatch, trainer_kwargs["devices"])
1928+
if trainer_kwargs.get("accelerator") == "auto":
1929+
# current parametrizations assume non-CUDA env
1930+
mock_cuda_count(monkeypatch, 0)
19281931

19291932
trainer = Trainer(**trainer_kwargs)
19301933

0 commit comments

Comments
 (0)