Skip to content

Commit 0a928e8

Browse files
authored
Reuse code in demos.BoringModel (#16242)
1 parent 5604226 commit 0a928e8

24 files changed

+77
-302
lines changed

examples/app_components/python/pl_script.py

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,5 @@
1-
import torch
2-
from torch.utils.data import DataLoader, Dataset
3-
4-
from pytorch_lightning import LightningModule, Trainer
5-
6-
7-
class RandomDataset(Dataset):
8-
def __init__(self, size: int, length: int):
9-
self.len = length
10-
self.data = torch.randn(length, size)
11-
12-
def __getitem__(self, index):
13-
return self.data[index]
14-
15-
def __len__(self):
16-
return self.len
17-
18-
19-
class BoringModel(LightningModule):
20-
def __init__(self):
21-
super().__init__()
22-
self.layer = torch.nn.Linear(32, 2)
23-
24-
def forward(self, x):
25-
return self.layer(x)
26-
27-
def loss(self, batch, prediction):
28-
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
29-
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
30-
31-
def training_step(self, batch, batch_idx):
32-
output = self(batch)
33-
loss = self.loss(batch, output)
34-
return {"loss": loss}
35-
36-
def validation_step(self, batch, batch_idx):
37-
output = self(batch)
38-
loss = self.loss(batch, output)
39-
return {"x": loss}
40-
41-
def test_step(self, batch, batch_idx):
42-
output = self(batch)
43-
loss = self.loss(batch, output)
44-
return {"y": loss}
45-
46-
def configure_optimizers(self):
47-
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
48-
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
49-
return [optimizer], [lr_scheduler]
50-
51-
def train_dataloader(self):
52-
return DataLoader(RandomDataset(32, 64))
53-
54-
val_dataloader = train_dataloader
55-
test_dataloader = train_dataloader
56-
predict_dataloader = train_dataloader
57-
1+
from pytorch_lightning import Trainer
2+
from pytorch_lightning.demos.boring_classes import BoringModel
583

594
if __name__ == "__main__":
605
model = BoringModel()

examples/app_components/python/pytorch_lightning_script.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

src/pytorch_lightning/demos/boring_classes.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -83,34 +83,39 @@ def __init__(self) -> None:
8383
- subclass
8484
- modify the behavior for what you want
8585
86-
class TestModel(BaseTestModel):
87-
def training_step(...):
88-
# do your own thing
86+
Example::
8987
90-
or:
88+
class TestModel(BoringModel):
89+
def training_step(self, ...):
90+
... # do your own thing
9191
92-
model = BaseTestModel()
93-
model.training_epoch_end = None
92+
training_epoch_end = None # disable hook
93+
94+
or
95+
96+
Example::
97+
98+
model = BoringModel()
99+
model.training_epoch_end = None # disable hook
94100
"""
95101
super().__init__()
96102
self.layer = torch.nn.Linear(32, 2)
97103

98104
def forward(self, x: Tensor) -> Tensor:
99105
return self.layer(x)
100106

101-
def loss(self, batch: Tensor, preds: Tensor) -> Tensor:
107+
def loss(self, preds: Tensor, labels: Optional[Tensor] = None) -> Tensor:
108+
if labels is None:
109+
labels = torch.ones_like(preds)
102110
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
103-
return torch.nn.functional.mse_loss(preds, torch.ones_like(preds))
111+
return torch.nn.functional.mse_loss(preds, labels)
104112

105-
def step(self, x: Tensor) -> Tensor:
106-
x = self(x)
107-
out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
108-
return out
113+
def step(self, batch: Tensor) -> Tensor:
114+
output = self(batch)
115+
return self.loss(output)
109116

110117
def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:
111-
output = self(batch)
112-
loss = self.loss(batch, output)
113-
return {"loss": loss}
118+
return {"loss": self.step(batch)}
114119

115120
def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT:
116121
return training_step_outputs
@@ -120,18 +125,14 @@ def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
120125
torch.stack([x["loss"] for x in outputs]).mean()
121126

122127
def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
123-
output = self(batch)
124-
loss = self.loss(batch, output)
125-
return {"x": loss}
128+
return {"x": self.step(batch)}
126129

127130
def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
128131
outputs = cast(List[Dict[str, Tensor]], outputs)
129132
torch.stack([x["x"] for x in outputs]).mean()
130133

131134
def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
132-
output = self(batch)
133-
loss = self.loss(batch, output)
134-
return {"y": loss}
135+
return {"y": self.step(batch)}
135136

136137
def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
137138
outputs = cast(List[Dict[str, Tensor]], outputs)
@@ -194,8 +195,7 @@ def __init__(self) -> None:
194195
def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:
195196
opt = self.optimizers()
196197
assert isinstance(opt, (Optimizer, LightningOptimizer))
197-
output = self(batch)
198-
loss = self.loss(batch, output)
198+
loss = self.step(batch)
199199
opt.zero_grad()
200200
self.manual_backward(loss)
201201
opt.step()

tests/tests_pytorch/accelerators/test_ipu.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,13 @@
3939

4040
class IPUModel(BoringModel):
4141
def training_step(self, batch, batch_idx):
42-
output = self(batch)
43-
loss = self.loss(batch, output)
44-
return loss
42+
return self.step(batch)
4543

4644
def validation_step(self, batch, batch_idx):
47-
output = self(batch)
48-
loss = self.loss(batch, output)
49-
return loss
45+
return self.step(batch)
5046

5147
def test_step(self, batch, batch_idx):
52-
output = self(batch)
53-
loss = self.loss(batch, output)
54-
return loss
48+
return self.step(batch)
5549

5650
def training_epoch_end(self, outputs) -> None:
5751
pass

tests/tests_pytorch/accelerators/test_tpu.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,7 @@ def on_train_batch_start(self, batch, batch_idx):
144144
def training_step(self, batch, batch_idx):
145145
self.called["training_step"] += 1
146146
opt = self.optimizers()
147-
output = self.layer(batch)
148-
loss = self.loss(batch, output)
147+
loss = self.step(batch)
149148

150149
if self.should_update:
151150
self.manual_backward(loss)

tests/tests_pytorch/callbacks/test_finetuning_callback.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,6 @@ def __init__(self):
5252
self.layer = torch.nn.Linear(32, 2)
5353
self.backbone.has_been_used = False
5454

55-
def training_step(self, batch, batch_idx):
56-
output = self(batch)
57-
loss = self.loss(batch, output)
58-
return {"loss": loss}
59-
6055
def forward(self, x):
6156
self.backbone.has_been_used = True
6257
x = self.backbone(x)
@@ -101,11 +96,6 @@ def __init__(self):
10196
self.layer = None
10297
self.backbone.has_been_used = False
10398

104-
def training_step(self, batch, batch_idx):
105-
output = self(batch)
106-
loss = self.loss(batch, output)
107-
return {"loss": loss}
108-
10999
def forward(self, x):
110100
self.backbone.has_been_used = True
111101
x = self.backbone(x)

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ def __init__(
6565
def training_step(self, batch, batch_idx):
6666
if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
6767
raise Exception("SWA crash test")
68-
output = self.forward(batch)
69-
loss = self.loss(batch, output)
70-
return {"loss": loss}
68+
return super().training_step(batch, batch_idx)
7169

7270
def train_dataloader(self):
7371
dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -859,8 +859,7 @@ def test_checkpoint_repeated_strategy(tmpdir):
859859

860860
class ExtendedBoringModel(BoringModel):
861861
def validation_step(self, batch, batch_idx):
862-
output = self.layer(batch)
863-
loss = self.loss(batch, output)
862+
loss = self.step(batch)
864863
self.log("val_loss", loss)
865864

866865
model = ExtendedBoringModel()
@@ -898,8 +897,7 @@ def test_checkpoint_repeated_strategy_extended(tmpdir):
898897

899898
class ExtendedBoringModel(BoringModel):
900899
def validation_step(self, batch, batch_idx):
901-
output = self.layer(batch)
902-
loss = self.loss(batch, output)
900+
loss = self.step(batch)
903901
self.log("val_loss", loss)
904902
return {"val_loss": loss}
905903

tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ def configure_optimizers(self):
3737
return [optimizer], [lr_scheduler]
3838

3939
def validation_step(self, batch, batch_idx):
40-
output = self.layer(batch)
41-
loss = self.loss(batch, output)
40+
loss = self.step(batch)
4241
self.log("val_loss", loss, on_epoch=True, prog_bar=True)
4342

4443
model = ExtendedBoringModel()

tests/tests_pytorch/core/test_lightning_optimizer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ def training_step(self, batch, batch_idx):
8585
assert isinstance(opt_2, LightningOptimizer)
8686

8787
def closure(opt):
88-
output = self.layer(batch)
89-
loss = self.loss(batch, output)
88+
loss = self.step(batch)
9089
opt.zero_grad()
9190
self.manual_backward(loss)
9291

@@ -323,8 +322,7 @@ def test_lightning_optimizer_keeps_hooks(tmpdir):
323322
def test_params_groups_and_state_are_accessible(tmpdir):
324323
class TestModel(BoringModel):
325324
def training_step(self, batch, batch_idx, optimizer_idx):
326-
output = self.layer(batch)
327-
loss = self.loss(batch, output)
325+
loss = self.step(batch)
328326
self.__loss = loss
329327
return loss
330328

0 commit comments

Comments
 (0)