| 
27 | 27 | from lightning.pytorch.core.hooks import DataHooks  | 
28 | 28 | from lightning.pytorch.core.mixins import HyperparametersMixin  | 
29 | 29 | from lightning.pytorch.core.saving import _load_from_checkpoint  | 
 | 30 | +from lightning.pytorch.utilities.exceptions import MisconfigurationException  | 
30 | 31 | from lightning.pytorch.utilities.model_helpers import _restricted_classmethod  | 
31 | 32 | from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS  | 
32 | 33 | 
 
  | 
@@ -247,25 +248,61 @@ def load_from_checkpoint(  | 
247 | 248 |         return cast(Self, loaded)  | 
248 | 249 | 
 
  | 
249 | 250 |     def __str__(self) -> str:  | 
250 |  | -        """Return a string representation of the datasets that are setup.  | 
 | 251 | +        """Return a string representation of the datasets that are set up.  | 
251 | 252 | 
  | 
252 | 253 |         Returns:  | 
253 | 254 |             A string representation of the datasets that are setup.  | 
254 | 255 | 
  | 
255 | 256 |         """  | 
256 |  | -        datasets_info: list[str] = []  | 
257 | 257 | 
 
  | 
258 |  | -        for attr_name in dir(self):  | 
259 |  | -            attr = getattr(self, attr_name)  | 
260 |  | - | 
261 |  | -            # Get Dataset information  | 
262 |  | -            if isinstance(attr, Dataset):  | 
263 |  | -                if isinstance(attr, Sized):  | 
264 |  | -                    datasets_info.append(f"name={attr_name}, size={len(attr)}")  | 
265 |  | -                else:  | 
266 |  | -                    datasets_info.append(f"name={attr_name}, size=Unavailable")  | 
267 |  | - | 
268 |  | -        if not datasets_info:  | 
269 |  | -            return "No datasets are set up."  | 
270 |  | - | 
271 |  | -        return os.linesep.join(datasets_info)  | 
 | 258 | +        def dataset_info(loader: DataLoader) -> tuple[str, str]:  | 
 | 259 | +            """Helper function to compute dataset information."""  | 
 | 260 | +            dataset = loader.dataset  | 
 | 261 | +            size: str  | 
 | 262 | +            size = str(len(dataset)) if isinstance(dataset, Sized) else "size: unknown"  | 
 | 263 | + | 
 | 264 | +            return str(dataset), size  | 
 | 265 | + | 
 | 266 | +        def loader_info(loader_instance: Union[DataLoader, Iterable[DataLoader]]) -> str:  | 
 | 267 | +            """Helper function to compute dataset information."""  | 
 | 268 | +            return apply_to_collection(loader_instance, tuple[str, str], dataset_info)  | 
 | 269 | + | 
 | 270 | +        dataloader_methods: list[tuple[str, str]] = [  | 
 | 271 | +            ("Train dataset", "train_dataloader"),  | 
 | 272 | +            ("Validation dataset", "val_dataloader"),  | 
 | 273 | +            ("Test dataset", "test_dataloader"),  | 
 | 274 | +            ("Prediction dataset", "predict_dataloader"),  | 
 | 275 | +        ]  | 
 | 276 | +        dataloader_info: dict[str, Union[tuple[str, str], Iterable[tuple[str, str]]]] = {}  | 
 | 277 | + | 
 | 278 | +        # Retrieve information for each dataloader method  | 
 | 279 | +        for method_pair in dataloader_methods:  | 
 | 280 | +            method_str, method_name = method_pair  | 
 | 281 | +            loader_method = getattr(self, method_name, None)  | 
 | 282 | + | 
 | 283 | +            if loader_method and callable(loader_method):  | 
 | 284 | +                try:  | 
 | 285 | +                    loader_instance = loader_method()  | 
 | 286 | +                    dataloader_info[method_str] = loader_info(loader_instance)  | 
 | 287 | +                except MisconfigurationException:  | 
 | 288 | +                    dataloader_info[method_str] = f"{method_str}: not implemented"  | 
 | 289 | +                except Exception as e:  | 
 | 290 | +                    dataloader_info[method_str] = f"{method_name}: error - {str(e)}"  | 
 | 291 | +            else:  | 
 | 292 | +                dataloader_info[method_str] = f"{method_name}: not callable"  | 
 | 293 | + | 
 | 294 | +        # Format the information  | 
 | 295 | +        dataloader_str: str = ""  | 
 | 296 | +        for method_str, method_info in dataloader_info.items():  | 
 | 297 | +            if isinstance(method_info, tuple[str, str]):  | 
 | 298 | +                dataloader_str += f"{{{method_str}: "  | 
 | 299 | +                dataloader_str += f"name={method_info[0]}, size={method_info[1]}"  | 
 | 300 | +                dataloader_str += f"}}{os.linesep}"  | 
 | 301 | +            else:  | 
 | 302 | +                dataloader_str += f"{{{method_str}: "  | 
 | 303 | +                for info in method_info:  | 
 | 304 | +                    dataloader_str += f"name={info[0]}, size={info[1]} ; "  | 
 | 305 | +                dataloader_str = dataloader_str[:-3]  | 
 | 306 | +                dataloader_str += f"}}{os.linesep}"  | 
 | 307 | + | 
 | 308 | +        return dataloader_str  | 
0 commit comments