Skip to content

Commit a58b3d7

Browse files
Added test cases for DataModule string function
Added alternative Boring Data Module implementations Added test cases for all possible options Added additional check for NotImplementedError in string function of DataModule
1 parent c4038e3 commit a58b3d7

File tree

3 files changed

+105
-5
lines changed

3 files changed

+105
-5
lines changed

src/lightning/pytorch/core/datamodule.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,17 +252,24 @@ def __str__(self) -> str:
252252
"""
253253
datasets_info = []
254254

255+
def len_implemented(obj):
256+
try:
257+
len(obj)
258+
return True
259+
except NotImplementedError:
260+
return False
261+
255262
for attr_name in dir(self):
256263
attr = getattr(self, attr_name)
257264

258265
# Get Dataset information
259266
if isinstance(attr, Dataset):
260-
if hasattr(attr, "__len__"):
261-
datasets_info.append(f"{attr_name}, dataset size={len(attr)}")
267+
if hasattr(attr, "__len__") and len_implemented(attr):
268+
datasets_info.append(f"name={attr_name}, size={len(attr)}")
262269
else:
263-
datasets_info.append(f"{attr_name}, dataset size=Unavailable")
270+
datasets_info.append(f"name={attr_name}, size=Unavailable")
264271

265272
if not datasets_info:
266273
return "No datasets are set up."
267274

268-
return "\n".join(datasets_info)
275+
return "\n".join(datasets_info) + "\n"

src/lightning/pytorch/demos/boring_classes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,38 @@ def predict_dataloader(self) -> DataLoader:
187187
return DataLoader(self.random_predict)
188188

189189

190+
class BoringDataModuleNoLen(LightningDataModule):
191+
"""
192+
.. warning:: This is meant for testing/debugging and is experimental.
193+
"""
194+
195+
def __init__(self) -> None:
196+
super().__init__()
197+
self.random_full = RandomIterableDataset(32, 64 * 4)
198+
199+
200+
class BoringDataModuleLenNotImplemented(LightningDataModule):
201+
"""
202+
.. warning:: This is meant for testing/debugging and is experimental.
203+
"""
204+
205+
def __init__(self) -> None:
206+
super().__init__()
207+
208+
class DS(Dataset):
209+
def __init__(self, size: int, length: int):
210+
self.len = length
211+
self.data = torch.randn(length, size)
212+
213+
def __getitem__(self, index: int) -> Tensor:
214+
return self.data[index]
215+
216+
def __len__(self) -> int:
217+
raise NotImplementedError
218+
219+
self.random_full = DS(32, 64 * 4)
220+
221+
190222
class ManualOptimBoringModel(BoringModel):
191223
"""
192224
.. warning:: This is meant for testing/debugging and is experimental.

tests/tests_pytorch/core/test_datamodules.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
import torch
2323
from lightning.pytorch import LightningDataModule, Trainer, seed_everything
2424
from lightning.pytorch.callbacks import ModelCheckpoint
25-
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
25+
from lightning.pytorch.demos.boring_classes import (
26+
BoringDataModule,
27+
BoringDataModuleLenNotImplemented,
28+
BoringDataModuleNoLen,
29+
BoringModel,
30+
)
2631
from lightning.pytorch.profilers.simple import SimpleProfiler
2732
from lightning.pytorch.trainer.states import TrainerFn
2833
from lightning.pytorch.utilities import AttributeDict
@@ -510,3 +515,59 @@ def prepare_data(self):
510515
durations = profiler.recorded_durations[key]
511516
assert len(durations) == 1
512517
assert durations[0] > 0
518+
519+
520+
def test_datamodule_string_no_datasets():
521+
dm = BoringDataModule()
522+
del dm.random_full
523+
expected_output = "No datasets are set up."
524+
assert str(dm) == expected_output
525+
526+
527+
def test_datamodule_string_no_length():
528+
dm = BoringDataModuleNoLen()
529+
expected_output = "name=random_full, size=Unavailable\n"
530+
assert str(dm) == expected_output
531+
532+
533+
def test_datamodule_string_length_not_implemented():
534+
dm = BoringDataModuleLenNotImplemented()
535+
expected_output = "name=random_full, size=Unavailable\n"
536+
assert str(dm) == expected_output
537+
538+
539+
def test_datamodule_string_fit_setup():
540+
dm = BoringDataModule()
541+
dm.setup(stage="fit")
542+
543+
expected_outputs = ["name=random_full, size=256\n", "name=random_train, size=64\n", "name=random_val, size=64\n"]
544+
output = str(dm)
545+
for expected_output in expected_outputs:
546+
assert expected_output in output
547+
548+
549+
def test_datamodule_string_validation_setup():
550+
dm = BoringDataModule()
551+
dm.setup(stage="validate")
552+
expected_outputs = ["name=random_full, size=256\n", "name=random_val, size=64\n"]
553+
output = str(dm)
554+
for expected_output in expected_outputs:
555+
assert expected_output in output
556+
557+
558+
def test_datamodule_string_test_setup():
559+
dm = BoringDataModule()
560+
dm.setup(stage="test")
561+
expected_outputs = ["name=random_full, size=256\n", "name=random_test, size=64\n"]
562+
output = str(dm)
563+
for expected_output in expected_outputs:
564+
assert expected_output in output
565+
566+
567+
def test_datamodule_string_predict_setup():
568+
dm = BoringDataModule()
569+
dm.setup(stage="predict")
570+
expected_outputs = ["name=random_full, size=256\n", "name=random_predict, size=64\n"]
571+
output = str(dm)
572+
for expected_output in expected_outputs:
573+
assert expected_output in output

0 commit comments

Comments
 (0)