|
23 | 23 |
|
24 | 24 | from pytorch_lightning import callbacks, Trainer |
25 | 25 | from pytorch_lightning.loggers import TensorBoardLogger |
| 26 | +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector |
26 | 27 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
27 | 28 | from tests.helpers import BoringModel, RandomDataset |
28 | 29 |
|
@@ -672,3 +673,29 @@ def val_dataloader(self): |
672 | 673 | enable_model_summary=False, |
673 | 674 | ) |
674 | 675 | trainer.fit(model) |
| 676 | + |
| 677 | + |
| 678 | +@pytest.mark.parametrize( |
| 679 | + ["kwargs", "expected"], |
| 680 | + [ |
| 681 | + ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}), |
| 682 | + ( |
| 683 | + {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}}, |
| 684 | + {"acc/dataloader_idx_0": 123}, |
| 685 | + ), |
| 686 | + ( |
| 687 | + {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}}, |
| 688 | + {"acc/dataloader_idx_10": 321}, |
| 689 | + ), |
| 690 | + ( |
| 691 | + {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}}, |
| 692 | + {"top_3_acc/dataloader_idx_3": 321}, |
| 693 | + ), |
| 694 | + # theoretical case, as `/dataloader_idx_3` would have been added |
| 695 | + ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}), |
| 696 | + ], |
| 697 | +) |
| 698 | +def test_filter_metrics_for_dataloader(kwargs, expected): |
| 699 | + """Logged metrics should only include metrics from the concerned dataloader.""" |
| 700 | + actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs) |
| 701 | + assert actual == expected |
0 commit comments