Skip to content

Commit 28ee099

Browse files
Add feature implementation to datamodule for str method
First implementation scetch
1 parent 48279a7 commit 28ee099

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

src/lightning/pytorch/core/datamodule.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,27 @@
1616
import inspect
1717
from typing import IO, Any, Dict, Iterable, Optional, Union, cast
1818

19+
import pytorch_lightning as pl
20+
from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
1921
from lightning_utilities import apply_to_collection
22+
from pytorch_lightning.core.hooks import DataHooks
23+
from pytorch_lightning.core.mixins import HyperparametersMixin
24+
from pytorch_lightning.core.saving import _load_from_checkpoint
25+
from pytorch_lightning.utilities.model_helpers import _restricted_classmethod
26+
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
2027
from torch.utils.data import DataLoader, Dataset, IterableDataset
2128
from typing_extensions import Self
2229

23-
import lightning.pytorch as pl
24-
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
25-
from lightning.pytorch.core.hooks import DataHooks
26-
from lightning.pytorch.core.mixins import HyperparametersMixin
27-
from lightning.pytorch.core.saving import _load_from_checkpoint
28-
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
29-
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
30-
3130

3231
class LightningDataModule(DataHooks, HyperparametersMixin):
3332
"""A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is
3433
consistent data splits, data preparation and transforms across models.
3534
3635
Example::
3736
38-
import lightning as L
37+
import lightning.pytorch as L
3938
import torch.utils.data as data
40-
from lightning.pytorch.demos.boring_classes import RandomDataset
39+
from pytorch_lightning.demos.boring_classes import RandomDataset
4140
4241
class MyDataModule(L.LightningDataModule):
4342
def prepare_data(self):
@@ -243,3 +242,32 @@ def load_from_checkpoint(
243242
**kwargs,
244243
)
245244
return cast(Self, loaded)
245+
246+
def __str__(self) -> str:
247+
"""Return a string representation of the datasets that are setup.
248+
249+
Returns:
250+
A string representation of the datasets that are setup.
251+
252+
"""
253+
datasets_info = []
254+
255+
for attr_name in dir(self):
256+
attr = getattr(self, attr_name)
257+
258+
# Get Dataset information
259+
if isinstance(attr, Dataset):
260+
if hasattr(attr, "__len__"):
261+
datasets_info.append(f"{attr_name}, dataset size={len(attr)}")
262+
else:
263+
datasets_info.append(f"{attr_name}, dataset size=Unavailable")
264+
elif isinstance(attr, (list, tuple)) and all(isinstance(item, Dataset) for item in attr):
265+
if all(hasattr(item, "__len__") for item in attr):
266+
datasets_info.append(f"{attr_name}, dataset size={[len(ds) for ds in attr]}")
267+
else:
268+
datasets_info.append(f"{attr_name}, dataset size=Unavailable")
269+
270+
if not datasets_info:
271+
return "No datasets are set up."
272+
273+
return "\n".join(datasets_info)

0 commit comments

Comments
 (0)