16 | 16 | import inspect
17 | 17 | from typing import IO, Any, Dict, Iterable, Optional, Union, cast
18 | 18 |
   | 19 | +import pytorch_lightning as pl
   | 20 | +from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
19 | 21 | from lightning_utilities import apply_to_collection
   | 22 | +from pytorch_lightning.core.hooks import DataHooks
   | 23 | +from pytorch_lightning.core.mixins import HyperparametersMixin
   | 24 | +from pytorch_lightning.core.saving import _load_from_checkpoint
   | 25 | +from pytorch_lightning.utilities.model_helpers import _restricted_classmethod
   | 26 | +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
20 | 27 | from torch.utils.data import DataLoader, Dataset, IterableDataset
21 | 28 | from typing_extensions import Self
22 | 29 |
23 |    | -import lightning.pytorch as pl
24 |    | -from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
25 |    | -from lightning.pytorch.core.hooks import DataHooks
26 |    | -from lightning.pytorch.core.mixins import HyperparametersMixin
27 |    | -from lightning.pytorch.core.saving import _load_from_checkpoint
28 |    | -from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
29 |    | -from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
30 |    | -
31 | 30 |
32 | 31 | class LightningDataModule(DataHooks, HyperparametersMixin):
33 | 32 |     """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is
34 | 33 |     consistent data splits, data preparation and transforms across models.
35 | 34 |
36 | 35 |     Example::
37 | 36 |
38 |    | -        import lightning as L
   | 37 | +        import pytorch_lightning as L
39 | 38 |         import torch.utils.data as data
40 |    | -        from lightning.pytorch.demos.boring_classes import RandomDataset
   | 39 | +        from pytorch_lightning.demos.boring_classes import RandomDataset
41 | 40 |
42 | 41 |         class MyDataModule(L.LightningDataModule):
43 | 42 |             def prepare_data(self):
@@ -243,3 +242,32 @@ def load_from_checkpoint(
243 | 242 |             **kwargs,
244 | 243 |         )
245 | 244 |         return cast(Self, loaded)
    | 245 | +
    | 246 | +    def __str__(self) -> str:
    | 247 | +        """Return a string representation of the datasets that are set up.
    | 248 | +
    | 249 | +        Returns:
    | 250 | +            A string representation of the datasets that are set up.
    | 251 | +
    | 252 | +        """
    | 253 | +        datasets_info = []
    | 254 | +
    | 255 | +        for attr_name in dir(self):
    | 256 | +            attr = getattr(self, attr_name)
    | 257 | +
    | 258 | +            # Get Dataset information
    | 259 | +            if isinstance(attr, Dataset):
    | 260 | +                if hasattr(attr, "__len__"):
    | 261 | +                    datasets_info.append(f"{attr_name}, dataset size={len(attr)}")
    | 262 | +                else:
    | 263 | +                    datasets_info.append(f"{attr_name}, dataset size=Unavailable")
    | 264 | +            elif isinstance(attr, (list, tuple)) and all(isinstance(item, Dataset) for item in attr):
    | 265 | +                if all(hasattr(item, "__len__") for item in attr):
    | 266 | +                    datasets_info.append(f"{attr_name}, dataset size={[len(ds) for ds in attr]}")
    | 267 | +                else:
    | 268 | +                    datasets_info.append(f"{attr_name}, dataset size=Unavailable")
    | 269 | +
    | 270 | +        if not datasets_info:
    | 271 | +            return "No datasets are set up."
    | 272 | +
    | 273 | +        return "\n".join(datasets_info)
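
For context, a minimal sketch of how the new __str__ behaves in practice. ToyDataset and ToyDataModule are hypothetical stand-ins invented for this illustration, and the output shown in the comments follows from the implementation above, assuming the datamodule defines no other Dataset-typed attributes.

    import pytorch_lightning as pl
    from torch.utils.data import Dataset


    class ToyDataset(Dataset):
        # Minimal map-style dataset, defined only for this illustration.
        def __init__(self, size: int) -> None:
            self.size = size

        def __len__(self) -> int:
            return self.size

        def __getitem__(self, index: int) -> int:
            return index


    class ToyDataModule(pl.LightningDataModule):
        def setup(self, stage: str) -> None:
            # Plain attributes holding Dataset instances are what __str__ picks up.
            self.train_dataset = ToyDataset(64)
            self.val_dataset = ToyDataset(16)


    dm = ToyDataModule()
    print(dm)              # -> No datasets are set up.
    dm.setup(stage="fit")
    print(dm)              # -> train_dataset, dataset size=64
                           #    val_dataset, dataset size=16

Because the method scans dir(self), any Dataset assigned to any attribute name is reported, not just the conventional train/val/test ones, and lists or tuples of datasets are reported with a list of sizes.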