diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 0c7a9840e2219..ff84c2fd8b199 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -14,7 +14,8 @@ """LightningDataModule for loading DataLoaders with ease.""" import inspect -from collections.abc import Iterable +import os +from collections.abc import Iterable, Sized from typing import IO, Any, Optional, Union, cast from lightning_utilities import apply_to_collection @@ -244,3 +245,75 @@ def load_from_checkpoint( **kwargs, ) return cast(Self, loaded) + + def __str__(self) -> str: + """Return a string representation of the datasets that are set up. + + Returns: + A string representation of the datasets that are setup. + + """ + + class dataset_info: + def __init__(self, available: bool, length: str) -> None: + self.available = available + self.length = length + + def retrieve_dataset_info(loader: DataLoader) -> dataset_info: + """Helper function to compute dataset information.""" + dataset = loader.dataset + size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA" + + return dataset_info(True, size) + + def loader_info( + loader: Union[DataLoader, Iterable[DataLoader]], + ) -> Union[dataset_info, Iterable[dataset_info]]: + """Helper function to compute dataset information.""" + return apply_to_collection(loader, DataLoader, retrieve_dataset_info) + + def extract_loader_info(methods: list[tuple[str, str]]) -> dict: + """Helper function to extract information for each dataloader method.""" + info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {} + for loader_name, func_name in methods: + loader_method = getattr(self, func_name, None) + + try: + loader = loader_method() # type: ignore + info[loader_name] = loader_info(loader) + except Exception: + info[loader_name] = dataset_info(False, "") + + return info + + def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) 
-> str: + """Helper function to format loader information.""" + output = [] + for loader_name, loader_info in info.items(): + # Single dataset + if isinstance(loader_info, dataset_info): + loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}" + # Iterable of datasets + else: + loader_info_formatted = " ; ".join( + "None" if not loader_info_i.available else f"{i}. size={loader_info_i.length}" + for i, loader_info_i in enumerate(loader_info, start=1) + ) + + output.append(f"{{{loader_name}: {loader_info_formatted}}}") + + return os.linesep.join(output) + + # Available dataloader methods + datamodule_loader_methods: list[tuple[str, str]] = [ + ("Train dataloader", "train_dataloader"), + ("Validation dataloader", "val_dataloader"), + ("Test dataloader", "test_dataloader"), + ("Predict dataloader", "predict_dataloader"), + ] + + # Retrieve information for each dataloader method + dataloader_info = extract_loader_info(datamodule_loader_methods) + # Format the information + dataloader_str = format_loader_info(dataloader_info) + return dataloader_str diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index 589524e1960b2..3855f31898b81 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class BoringDataModuleNoLen(LightningDataModule):
    """
    .. warning::  This is meant for testing/debugging and is experimental.

    Datamodule whose dataloaders are all backed by ``RandomIterableDataset`` instances. Each dataset is only
    created by the ``setup`` stage that needs it.
    """

    def __init__(self) -> None:
        super().__init__()

    def setup(self, stage: str) -> None:
        """Create the datasets required for ``stage``."""
        if stage == "fit":
            self.random_train = RandomIterableDataset(32, 512)

        # Validation data is needed both while fitting and for a standalone validation run.
        if stage in ("fit", "validate"):
            self.random_val = RandomIterableDataset(32, 128)

        if stage == "test":
            self.random_test = RandomIterableDataset(32, 256)

        if stage == "predict":
            self.random_predict = RandomIterableDataset(32, 64)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.random_train)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.random_val)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.random_test)

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(self.random_predict)


class IterableBoringDataModule(LightningDataModule):
    """
    .. warning::  This is meant for testing/debugging and is experimental.

    Datamodule whose dataloader hooks each return a list of two dataloaders: one over a ``RandomDataset`` and
    one over a ``RandomIterableDataset``. Each list of datasets is only created by the ``setup`` stage that
    needs it.
    """

    def __init__(self) -> None:
        super().__init__()

    def setup(self, stage: str) -> None:
        """Create the dataset lists required for ``stage``."""
        if stage == "fit":
            self.train_datasets = [
                RandomDataset(4, 16),
                RandomIterableDataset(4, 16),
            ]

        # Validation data is needed both while fitting and for a standalone validation run.
        if stage in ("fit", "validate"):
            self.val_datasets = [
                RandomDataset(4, 32),
                RandomIterableDataset(4, 32),
            ]

        if stage == "test":
            self.test_datasets = [
                RandomDataset(4, 64),
                RandomIterableDataset(4, 64),
            ]

        if stage == "predict":
            self.predict_datasets = [
                RandomDataset(4, 128),
                RandomIterableDataset(4, 128),
            ]

    # Every hook wraps each dataset of the matching list in its own ``DataLoader`` and therefore returns a
    # collection of loaders, not a single ``DataLoader``.

    def train_dataloader(self) -> Iterable[DataLoader]:
        return apply_to_collection(self.train_datasets, Dataset, DataLoader)

    def val_dataloader(self) -> Iterable[DataLoader]:
        return apply_to_collection(self.val_datasets, Dataset, DataLoader)

    def test_dataloader(self) -> Iterable[DataLoader]:
        return apply_to_collection(self.test_datasets, Dataset, DataLoader)

    def predict_dataloader(self) -> Iterable[DataLoader]:
        return apply_to_collection(self.predict_datasets, Dataset, DataLoader)
+import os import pickle from argparse import Namespace from dataclasses import dataclass @@ -22,7 +23,12 @@ import torch from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel +from lightning.pytorch.demos.boring_classes import ( + BoringDataModule, + BoringDataModuleNoLen, + BoringModel, + IterableBoringDataModule, +) from lightning.pytorch.profilers.simple import SimpleProfiler from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities import AttributeDict @@ -510,3 +516,107 @@ def prepare_data(self): durations = profiler.recorded_durations[key] assert len(durations) == 1 assert durations[0] > 0 + + +def test_datamodule_string_not_available(): + dm = BoringDataModule() + + expected_output = ( + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + out = str(dm) + + assert out == expected_output + + +def test_datamodule_string_fit_setup(): + dm = BoringDataModule() + dm.setup(stage="fit") + + expected_output = ( + f"{{Train dataloader: size=64}}{os.linesep}" + f"{{Validation dataloader: size=64}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_validation_setup(): + dm = BoringDataModule() + dm.setup(stage="validate") + + expected_output = ( + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: size=64}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_test_setup(): + dm = BoringDataModule() + dm.setup(stage="test") + + expected_output = ( + f"{{Train dataloader: 
None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: size=64}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_predict_setup(): + dm = BoringDataModule() + dm.setup(stage="predict") + + expected_output = ( + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: size=64}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_no_len(): + dm = BoringDataModuleNoLen() + dm.setup("fit") + + expected_output = ( + f"{{Train dataloader: size=NA}}{os.linesep}" + f"{{Validation dataloader: size=NA}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert output == expected_output + + +def test_datamodule_string_iterable(): + dm = IterableBoringDataModule() + dm.setup("fit") + + expected_output = ( + f"{{Train dataloader: 1. size=16 ; 2. size=NA}}{os.linesep}" + f"{{Validation dataloader: 1. size=32 ; 2. size=NA}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert output == expected_output