|
14 | 14 | """LightningDataModule for loading DataLoaders with ease."""
|
15 | 15 |
|
16 | 16 | import inspect
|
17 |
| -from collections.abc import Iterable |
| 17 | +import os |
| 18 | +from collections.abc import Iterable, Sized |
18 | 19 | from typing import IO, Any, Optional, Union, cast
|
19 | 20 |
|
20 | 21 | from lightning_utilities import apply_to_collection
|
@@ -244,3 +245,75 @@ def load_from_checkpoint(
|
244 | 245 | **kwargs,
|
245 | 246 | )
|
246 | 247 | return cast(Self, loaded)
|
| 248 | + |
| 249 | + def __str__(self) -> str: |
| 250 | + """Return a string representation of the datasets that are set up. |
| 251 | +
|
| 252 | + Returns: |
| 253 | + A string representation of the datasets that are setup. |
| 254 | +
|
| 255 | + """ |
| 256 | + |
| 257 | + class dataset_info: |
| 258 | + def __init__(self, available: bool, length: str) -> None: |
| 259 | + self.available = available |
| 260 | + self.length = length |
| 261 | + |
| 262 | + def retrieve_dataset_info(loader: DataLoader) -> dataset_info: |
| 263 | + """Helper function to compute dataset information.""" |
| 264 | + dataset = loader.dataset |
| 265 | + size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA" |
| 266 | + |
| 267 | + return dataset_info(True, size) |
| 268 | + |
| 269 | + def loader_info( |
| 270 | + loader: Union[DataLoader, Iterable[DataLoader]], |
| 271 | + ) -> Union[dataset_info, Iterable[dataset_info]]: |
| 272 | + """Helper function to compute dataset information.""" |
| 273 | + return apply_to_collection(loader, DataLoader, retrieve_dataset_info) |
| 274 | + |
| 275 | + def extract_loader_info(methods: list[tuple[str, str]]) -> dict: |
| 276 | + """Helper function to extract information for each dataloader method.""" |
| 277 | + info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {} |
| 278 | + for loader_name, func_name in methods: |
| 279 | + loader_method = getattr(self, func_name, None) |
| 280 | + |
| 281 | + try: |
| 282 | + loader = loader_method() # type: ignore |
| 283 | + info[loader_name] = loader_info(loader) |
| 284 | + except Exception: |
| 285 | + info[loader_name] = dataset_info(False, "") |
| 286 | + |
| 287 | + return info |
| 288 | + |
| 289 | + def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str: |
| 290 | + """Helper function to format loader information.""" |
| 291 | + output = [] |
| 292 | + for loader_name, loader_info in info.items(): |
| 293 | + # Single dataset |
| 294 | + if isinstance(loader_info, dataset_info): |
| 295 | + loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}" |
| 296 | + # Iterable of datasets |
| 297 | + else: |
| 298 | + loader_info_formatted = " ; ".join( |
| 299 | + "None" if not loader_info_i.available else f"{i}. size={loader_info_i.length}" |
| 300 | + for i, loader_info_i in enumerate(loader_info, start=1) |
| 301 | + ) |
| 302 | + |
| 303 | + output.append(f"{{{loader_name}: {loader_info_formatted}}}") |
| 304 | + |
| 305 | + return os.linesep.join(output) |
| 306 | + |
| 307 | + # Available dataloader methods |
| 308 | + datamodule_loader_methods: list[tuple[str, str]] = [ |
| 309 | + ("Train dataloader", "train_dataloader"), |
| 310 | + ("Validation dataloader", "val_dataloader"), |
| 311 | + ("Test dataloader", "test_dataloader"), |
| 312 | + ("Predict dataloader", "predict_dataloader"), |
| 313 | + ] |
| 314 | + |
| 315 | + # Retrieve information for each dataloader method |
| 316 | + dataloader_info = extract_loader_info(datamodule_loader_methods) |
| 317 | + # Format the information |
| 318 | + dataloader_str = format_loader_info(dataloader_info) |
| 319 | + return dataloader_str |
0 commit comments