Skip to content

Commit 6d05cfe

Browse files
Add string method to datamodule
Switched from Dataset based implementation to Dataloader based implementation
1 parent e03aefb commit 6d05cfe

File tree

1 file changed

+53
-16
lines changed

1 file changed

+53
-16
lines changed

src/lightning/pytorch/core/datamodule.py

Lines changed: 53 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
from lightning.pytorch.core.hooks import DataHooks
2828
from lightning.pytorch.core.mixins import HyperparametersMixin
2929
from lightning.pytorch.core.saving import _load_from_checkpoint
30+
from lightning.pytorch.utilities.exceptions import MisconfigurationException
3031
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
3132
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
3233

@@ -247,25 +248,61 @@ def load_from_checkpoint(
247248
return cast(Self, loaded)
248249

249250
def __str__(self) -> str:
250-
"""Return a string representation of the datasets that are setup.
251+
"""Return a string representation of the datasets that are set up.
251252
252253
Returns:
253254
A string representation of the datasets that are set up.
254255
255256
"""
256-
datasets_info: list[str] = []
257257

258-
for attr_name in dir(self):
259-
attr = getattr(self, attr_name)
260-
261-
# Get Dataset information
262-
if isinstance(attr, Dataset):
263-
if isinstance(attr, Sized):
264-
datasets_info.append(f"name={attr_name}, size={len(attr)}")
265-
else:
266-
datasets_info.append(f"name={attr_name}, size=Unavailable")
267-
268-
if not datasets_info:
269-
return "No datasets are set up."
270-
271-
return os.linesep.join(datasets_info)
258+
def dataset_info(loader: DataLoader) -> tuple[str, str]:
259+
"""Helper function to compute dataset information."""
260+
dataset = loader.dataset
261+
size: str
262+
size = str(len(dataset)) if isinstance(dataset, Sized) else "size: unknown"
263+
264+
return str(dataset), size
265+
266+
def loader_info(loader_instance: Union[DataLoader, Iterable[DataLoader]]) -> str:
267+
"""Helper function to compute dataset information for a dataloader or an iterable of dataloaders."""
268+
return apply_to_collection(loader_instance, tuple[str, str], dataset_info)
269+
270+
dataloader_methods: list[tuple[str, str]] = [
271+
("Train dataset", "train_dataloader"),
272+
("Validation dataset", "val_dataloader"),
273+
("Test dataset", "test_dataloader"),
274+
("Prediction dataset", "predict_dataloader"),
275+
]
276+
dataloader_info: dict[str, Union[tuple[str, str], Iterable[tuple[str, str]]]] = {}
277+
278+
# Retrieve information for each dataloader method
279+
for method_pair in dataloader_methods:
280+
method_str, method_name = method_pair
281+
loader_method = getattr(self, method_name, None)
282+
283+
if loader_method and callable(loader_method):
284+
try:
285+
loader_instance = loader_method()
286+
dataloader_info[method_str] = loader_info(loader_instance)
287+
except MisconfigurationException:
288+
dataloader_info[method_str] = f"{method_str}: not implemented"
289+
except Exception as e:
290+
dataloader_info[method_str] = f"{method_name}: error - {str(e)}"
291+
else:
292+
dataloader_info[method_str] = f"{method_name}: not callable"
293+
294+
# Format the information
295+
dataloader_str: str = ""
296+
for method_str, method_info in dataloader_info.items():
297+
if isinstance(method_info, tuple[str, str]):
298+
dataloader_str += f"{{{method_str}: "
299+
dataloader_str += f"name={method_info[0]}, size={method_info[1]}"
300+
dataloader_str += f"}}{os.linesep}"
301+
else:
302+
dataloader_str += f"{{{method_str}: "
303+
for info in method_info:
304+
dataloader_str += f"name={info[0]}, size={info[1]} ; "
305+
dataloader_str = dataloader_str[:-3]
306+
dataloader_str += f"}}{os.linesep}"
307+
308+
return dataloader_str

0 commit comments

Comments
 (0)