Skip to content

Commit 21029d2

Browse files
Implementing str method for datamodule
Refactored code and made it more readable by implementing more abstarct fucntion methods Adjusted tests Removed debug statements Removed TODO comments
1 parent b08e1fe commit 21029d2

File tree

2 files changed

+54
-46
lines changed

2 files changed

+54
-46
lines changed

src/lightning/pytorch/core/datamodule.py

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -254,52 +254,68 @@ def __str__(self) -> str:
254254
255255
"""
256256

257-
def dataset_info(loader: DataLoader) -> tuple[str, str]:
257+
class dataset_info:
258+
def __init__(self, available: str, length: str) -> None:
259+
self.available = available
260+
self.length = length
261+
262+
def retrieve_dataset_info(loader: DataLoader) -> dataset_info:
258263
"""Helper function to compute dataset information."""
259264
dataset = loader.dataset
260265
size: str = str(len(dataset)) if isinstance(dataset, Sized) else "unknown"
266+
output = dataset_info("yes", size)
267+
return output
261268

262-
return "yes", size
263-
264-
def loader_info(loader_instance: Union[DataLoader, Iterable[DataLoader]]) -> str:
269+
def loader_info(
270+
loader_instance: Union[DataLoader, Iterable[DataLoader]],
271+
) -> Union[dataset_info, Iterable[dataset_info]]:
265272
"""Helper function to compute dataset information."""
266-
result = apply_to_collection(loader_instance, DataLoader, dataset_info)
273+
result = apply_to_collection(loader_instance, DataLoader, retrieve_dataset_info)
267274

268275
return result
269276

277+
def extract_loader_info(methods: list[tuple[str, str]]) -> dict:
278+
"""Helper function to extract information for each dataloader method."""
279+
info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {}
280+
for method_str, function_name in methods:
281+
loader_method = getattr(self, function_name, None)
282+
283+
try:
284+
loader_instance = loader_method()
285+
info[method_str] = loader_info(loader_instance)
286+
except Exception:
287+
info[method_str] = dataset_info("no", "unknown")
288+
289+
return info
290+
291+
def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str:
292+
"""Helper function to format loader information."""
293+
lines = []
294+
for method_str, method_info in info.items():
295+
# Single dataset
296+
if isinstance(method_info, dataset_info):
297+
data_info = f"{{{method_str}: available={method_info.available}, size={method_info.length}}}"
298+
lines.append(data_info)
299+
# Iterable of datasets
300+
else:
301+
itr_data_info = " ; ".join(
302+
f"{i}. available={dataset.available}, size={dataset.length}"
303+
for i, dataset in enumerate(method_info, start=1)
304+
)
305+
lines.append(f"{{{method_str}: {itr_data_info}}}")
306+
307+
return os.linesep.join(lines)
308+
309+
# Available dataloader methods
270310
dataloader_methods: list[tuple[str, str]] = [
271311
("Train dataset", "train_dataloader"),
272312
("Validation dataset", "val_dataloader"),
273313
("Test dataset", "test_dataloader"),
274314
("Prediction dataset", "predict_dataloader"),
275315
]
276-
dataloader_info: dict[str, Union[tuple[str, str], Iterable[tuple[str, str]]]] = {}
277316

278317
# Retrieve information for each dataloader method
279-
for method_pair in dataloader_methods:
280-
method_str, method_name = method_pair
281-
loader_method = getattr(self, method_name, None)
282-
283-
try:
284-
loader_instance = loader_method()
285-
dataloader_info[method_str] = loader_info(loader_instance)
286-
except Exception:
287-
dataloader_info[method_str] = ("no", "unknown")
288-
318+
dataloader_info = extract_loader_info(dataloader_methods)
289319
# Format the information
290-
dataloader_str: str = ""
291-
for method_str, method_info in dataloader_info.items():
292-
# Single data set
293-
if isinstance(method_info, tuple):
294-
dataloader_str += f"{{{method_str}: "
295-
dataloader_str += f"available={method_info[0]}, size={method_info[1]}"
296-
dataloader_str += f"}}{os.linesep}"
297-
else:
298-
# Iterable of datasets
299-
dataloader_str += f"{{{method_str}: "
300-
for i, info in enumerate(method_info, start=1):
301-
dataloader_str += f"{i}. available={info[0]}, size={info[1]} ; "
302-
dataloader_str = dataloader_str[:-3]
303-
dataloader_str += f"}}{os.linesep}"
304-
320+
dataloader_str = format_loader_info(dataloader_info)
305321
return dataloader_str

tests/tests_pytorch/core/test_datamodules.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -518,22 +518,20 @@ def prepare_data(self):
518518
assert durations[0] > 0
519519

520520

521-
# TODO: Remove last os.linesep
522521
def test_datamodule_string_not_available():
523522
dm = BoringDataModule()
524523

525524
expected_output = (
526525
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
527526
f"{{Validation dataset: available=no, size=unknown}}{os.linesep}"
528527
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
529-
f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}"
528+
f"{{Prediction dataset: available=no, size=unknown}}"
530529
)
531530
out = str(dm)
532531

533532
assert out == expected_output
534533

535534

536-
# TODO Remove prints
537535
def test_datamodule_string_fit_setup():
538536
dm = BoringDataModule()
539537
dm.setup(stage="fit")
@@ -542,16 +540,10 @@ def test_datamodule_string_fit_setup():
542540
f"{{Train dataset: available=yes, size=64}}{os.linesep}"
543541
f"{{Validation dataset: available=yes, size=64}}{os.linesep}"
544542
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
545-
f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}"
543+
f"{{Prediction dataset: available=no, size=unknown}}"
546544
)
547545
output = str(dm)
548546

549-
print()
550-
print(repr(expected_output))
551-
print()
552-
print(repr(output))
553-
print()
554-
555547
assert expected_output == output
556548

557549

@@ -563,7 +555,7 @@ def test_datamodule_string_validation_setup():
563555
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
564556
f"{{Validation dataset: available=yes, size=64}}{os.linesep}"
565557
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
566-
f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}"
558+
f"{{Prediction dataset: available=no, size=unknown}}"
567559
)
568560
output = str(dm)
569561

@@ -578,7 +570,7 @@ def test_datamodule_string_test_setup():
578570
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
579571
f"{{Validation dataset: available=no, size=unknown}}{os.linesep}"
580572
f"{{Test dataset: available=yes, size=64}}{os.linesep}"
581-
f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}"
573+
f"{{Prediction dataset: available=no, size=unknown}}"
582574
)
583575
output = str(dm)
584576

@@ -593,7 +585,7 @@ def test_datamodule_string_predict_setup():
593585
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
594586
f"{{Validation dataset: available=no, size=unknown}}{os.linesep}"
595587
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
596-
f"{{Prediction dataset: available=yes, size=64}}{os.linesep}"
588+
f"{{Prediction dataset: available=yes, size=64}}"
597589
)
598590
output = str(dm)
599591

@@ -608,7 +600,7 @@ def test_datamodule_string_no_len():
608600
f"{{Train dataset: available=yes, size=unknown}}{os.linesep}"
609601
f"{{Validation dataset: available=yes, size=unknown}}{os.linesep}"
610602
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
611-
f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}"
603+
f"{{Prediction dataset: available=no, size=unknown}}"
612604
)
613605
output = str(dm)
614606

@@ -623,7 +615,7 @@ def test_datamodule_string_iterable():
623615
f"{{Train dataset: 1. available=yes, size=16 ; 2. available=yes, size=unknown}}{os.linesep}"
624616
f"{{Validation dataset: 1. available=yes, size=32 ; 2. available=yes, size=unknown}}{os.linesep}"
625617
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
626-
f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}"
618+
f"{{Prediction dataset: available=no, size=unknown}}"
627619
)
628620
output = str(dm)
629621

0 commit comments

Comments
 (0)