|
22 | 22 | import torch
|
23 | 23 | from lightning.pytorch import LightningDataModule, Trainer, seed_everything
|
24 | 24 | from lightning.pytorch.callbacks import ModelCheckpoint
|
25 |
| -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel |
| 25 | +from lightning.pytorch.demos.boring_classes import ( |
| 26 | + BoringDataModule, |
| 27 | + BoringDataModuleLenNotImplemented, |
| 28 | + BoringDataModuleNoLen, |
| 29 | + BoringModel, |
| 30 | +) |
26 | 31 | from lightning.pytorch.profilers.simple import SimpleProfiler
|
27 | 32 | from lightning.pytorch.trainer.states import TrainerFn
|
28 | 33 | from lightning.pytorch.utilities import AttributeDict
|
@@ -510,3 +515,59 @@ def prepare_data(self):
|
510 | 515 | durations = profiler.recorded_durations[key]
|
511 | 516 | assert len(durations) == 1
|
512 | 517 | assert durations[0] > 0
|
| 518 | + |
| 519 | + |
| 520 | +def test_datamodule_string_no_datasets(): |
| 521 | + dm = BoringDataModule() |
| 522 | + del dm.random_full |
| 523 | + expected_output = "No datasets are set up." |
| 524 | + assert str(dm) == expected_output |
| 525 | + |
| 526 | + |
| 527 | +def test_datamodule_string_no_length(): |
| 528 | + dm = BoringDataModuleNoLen() |
| 529 | + expected_output = "name=random_full, size=Unavailable\n" |
| 530 | + assert str(dm) == expected_output |
| 531 | + |
| 532 | + |
| 533 | +def test_datamodule_string_length_not_implemented(): |
| 534 | + dm = BoringDataModuleLenNotImplemented() |
| 535 | + expected_output = "name=random_full, size=Unavailable\n" |
| 536 | + assert str(dm) == expected_output |
| 537 | + |
| 538 | + |
| 539 | +def test_datamodule_string_fit_setup(): |
| 540 | + dm = BoringDataModule() |
| 541 | + dm.setup(stage="fit") |
| 542 | + |
| 543 | + expected_outputs = ["name=random_full, size=256\n", "name=random_train, size=64\n", "name=random_val, size=64\n"] |
| 544 | + output = str(dm) |
| 545 | + for expected_output in expected_outputs: |
| 546 | + assert expected_output in output |
| 547 | + |
| 548 | + |
| 549 | +def test_datamodule_string_validation_setup(): |
| 550 | + dm = BoringDataModule() |
| 551 | + dm.setup(stage="validate") |
| 552 | + expected_outputs = ["name=random_full, size=256\n", "name=random_val, size=64\n"] |
| 553 | + output = str(dm) |
| 554 | + for expected_output in expected_outputs: |
| 555 | + assert expected_output in output |
| 556 | + |
| 557 | + |
| 558 | +def test_datamodule_string_test_setup(): |
| 559 | + dm = BoringDataModule() |
| 560 | + dm.setup(stage="test") |
| 561 | + expected_outputs = ["name=random_full, size=256\n", "name=random_test, size=64\n"] |
| 562 | + output = str(dm) |
| 563 | + for expected_output in expected_outputs: |
| 564 | + assert expected_output in output |
| 565 | + |
| 566 | + |
| 567 | +def test_datamodule_string_predict_setup(): |
| 568 | + dm = BoringDataModule() |
| 569 | + dm.setup(stage="predict") |
| 570 | + expected_outputs = ["name=random_full, size=256\n", "name=random_predict, size=64\n"] |
| 571 | + output = str(dm) |
| 572 | + for expected_output in expected_outputs: |
| 573 | + assert expected_output in output |
0 commit comments