Skip to content

Commit 1f4a77c

Browse files
authored
improve importing demos (#20446)
1 parent 75d7357 commit 1f4a77c

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed
Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,15 @@
1-
from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM # noqa: F401
2-
from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 # noqa: F401
1+
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, DemoModel
2+
from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM
3+
from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2
4+
5+
__all__ = [
6+
"LightningLSTM",
7+
"SequenceSampler",
8+
"SimpleLSTM",
9+
"LightningTransformer",
10+
"Transformer",
11+
"WikiText2",
12+
"BoringModel",
13+
"BoringDataModule",
14+
"DemoModel",
15+
]

tests/tests_pytorch/test_cli.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -478,8 +478,8 @@ def test_lightning_cli_print_config():
478478
"any.py",
479479
"predict",
480480
"--seed_everything=1234",
481-
"--model=lightning.pytorch.demos.boring_classes.BoringModel",
482-
"--data=lightning.pytorch.demos.boring_classes.BoringDataModule",
481+
"--model=lightning.pytorch.demos.BoringModel",
482+
"--data=lightning.pytorch.demos.BoringDataModule",
483483
"--print_config",
484484
]
485485
out = StringIO()
@@ -492,8 +492,8 @@ def test_lightning_cli_print_config():
492492

493493
outval = yaml.safe_load(text)
494494
assert outval["seed_everything"] == 1234
495-
assert outval["model"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringModel"
496-
assert outval["data"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringDataModule"
495+
assert outval["model"]["class_path"] == "lightning.pytorch.demos.BoringModel"
496+
assert outval["data"]["class_path"] == "lightning.pytorch.demos.BoringDataModule"
497497
assert outval["ckpt_path"] is None
498498

499499

0 commit comments

Comments
 (0)