
Commit e5c122c

jxtngx authored and lantiga committed
Update docs/source-pytorch/common/lightning_module.rst (#18451)
Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit a013386)
1 parent ae2138d · commit e5c122c

File tree: 1 file changed (+125 −101)

docs/source-pytorch/common/lightning_module.rst

Lines changed: 125 additions & 101 deletions
@@ -85,38 +85,42 @@ Here are the only required methods.
 .. code-block:: python

     import lightning.pytorch as pl
-    import torch.nn as nn
-    import torch.nn.functional as F
+    import torch

+    from lightning.pytorch.demos import Transformer

-    class LitModel(pl.LightningModule):
-        def __init__(self):
+
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.l1 = nn.Linear(28 * 28, 10)
+            self.model = Transformer(vocab_size=vocab_size)

-        def forward(self, x):
-            return torch.relu(self.l1(x.view(x.size(0), -1)))
+        def forward(self, inputs, target):
+            return self.model(inputs, target)

         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss

         def configure_optimizers(self):
-            return torch.optim.Adam(self.parameters(), lr=0.02)
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)

 Which you can train by doing:

 .. code-block:: python

-    train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
-    trainer = pl.Trainer(max_epochs=1)
-    model = LitModel()
+    from torch.utils.data import DataLoader
+
+    dataset = pl.demos.WikiText2()
+    dataloader = DataLoader(dataset)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)

-    trainer.fit(model, train_dataloaders=train_loader)
+    trainer = pl.Trainer(fast_dev_run=100)
+    trainer.fit(model=model, train_dataloaders=dataloader)

-The LightningModule has many convenience methods, but the core ones you need to know about are:
+The LightningModule has many convenient methods, but the core ones you need to know about are:

 .. list-table::
    :widths: 50 50
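
For reference, a minimal end-to-end sketch combining the updated example above (assuming `lightning` 2.x is installed and `WikiText2` can download its data to the working directory):

.. code-block:: python

    import lightning.pytorch as pl
    import torch
    from torch.utils.data import DataLoader

    from lightning.pytorch.demos import Transformer, WikiText2


    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)

        def forward(self, inputs, target):
            return self.model(inputs, target)

        def training_step(self, batch, batch_idx):
            inputs, target = batch
            output = self(inputs, target)
            # the demo Transformer emits log-probabilities, so NLL loss applies directly
            return torch.nn.functional.nll_loss(output, target.view(-1))

        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=0.1)


    if __name__ == "__main__":
        dataset = WikiText2()  # downloads a small text corpus on first use
        dataloader = DataLoader(dataset)
        model = LightningTransformer(vocab_size=dataset.vocab_size)
        trainer = pl.Trainer(fast_dev_run=100)  # runs only 100 batches as a sanity check
        trainer.fit(model=model, train_dataloaders=dataloader)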
@@ -152,15 +156,15 @@ To activate the training loop, override the :meth:`~lightning.pytorch.core.modul

 .. code-block:: python

-    class LitClassifier(pl.LightningModule):
-        def __init__(self, model):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
+            self.model = Transformer(vocab_size=vocab_size)

         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss

 Under the hood, Lightning does the following (pseudocode):
@@ -191,15 +195,15 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~lightning

 .. code-block:: python

-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
+    def training_step(self, batch, batch_idx):
+        inputs, target = batch
+        output = self.model(inputs, target)
+        loss = torch.nn.functional.nll_loss(output, target.view(-1))

-        # logs metrics for each training_step,
-        # and the average across the epoch, to the progress bar and logger
-        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return loss
+        # logs metrics for each training_step,
+        # and the average across the epoch, to the progress bar and logger
+        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        return loss

 The :meth:`~lightning.pytorch.core.module.LightningModule.log` method automatically reduces the
 requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:
@@ -230,25 +234,25 @@ override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` metho

 .. code-block:: python

-    def __init__(self):
-        super().__init__()
-        self.training_step_outputs = []
-
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        preds = ...
-        self.training_step_outputs.append(preds)
-        return loss
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.training_step_outputs = []

+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            preds = ...
+            self.training_step_outputs.append(preds)
+            return loss

-    def on_train_epoch_end(self):
-        all_preds = torch.stack(self.training_step_outputs)
-        # do something with all preds
-        ...
-        self.training_step_outputs.clear()  # free memory
+        def on_train_epoch_end(self):
+            all_preds = torch.stack(self.training_step_outputs)
+            # do something with all preds
+            ...
+            self.training_step_outputs.clear()  # free memory


 ------------------
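
The ``preds = ...`` placeholder above is left open in the docs; a minimal sketch of one way to fill it inside the ``LightningTransformer`` class, assuming the demo Transformer's flattened (tokens × vocab) log-probability output:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        # illustrative choice: keep the predicted token ids for epoch-end inspection
        preds = output.argmax(dim=-1)
        self.training_step_outputs.append(preds)
        return loss

    def on_train_epoch_end(self):
        # cat rather than stack, since the final batch may have a different size
        all_preds = torch.cat(self.training_step_outputs)
        ...
        self.training_step_outputs.clear()  # free memory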
@@ -264,10 +268,10 @@ To activate the validation loop while training, override the :meth:`~lightning.p

 .. code-block:: python

-    class LitModel(pl.LightningModule):
+    class LightningTransformer(pl.LightningModule):
         def validation_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
+            inputs, target = batch
+            output = self.model(inputs, target)
             loss = F.cross_entropy(y_hat, y)
             self.log("val_loss", loss)

@@ -300,8 +304,8 @@ and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.

 .. code-block:: python

-    model = Model()
-    trainer = Trainer()
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    trainer = pl.Trainer()
     trainer.validate(model)

 .. note::
@@ -322,25 +326,26 @@ Note that this method is called before :meth:`~lightning.pytorch.LightningModule

 .. code-block:: python

-    def __init__(self):
-        super().__init__()
-        self.validation_step_outputs = []
-
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        pred = ...
-        self.validation_step_outputs.append(pred)
-        return pred
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.validation_step_outputs = []

+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            pred = ...
+            self.validation_step_outputs.append(pred)
+            return pred

-    def on_validation_epoch_end(self):
-        all_preds = torch.stack(self.validation_step_outputs)
-        # do something with all preds
-        ...
-        self.validation_step_outputs.clear()  # free memory
+        def on_validation_epoch_end(self):
+            all_preds = torch.stack(self.validation_step_outputs)
+            # do something with all preds
+            ...
+            self.validation_step_outputs.clear()  # free memory

 ----------------

@@ -358,9 +363,10 @@ The only difference is that the test loop is only called when :meth:`~lightning.

 .. code-block:: python

-    model = Model()
-    trainer = Trainer()
-    trainer.fit(model)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)

     # automatically loads the best weights for you
     trainer.test(model)
@@ -370,17 +376,23 @@ There are two ways to call ``test()``:
 .. code-block:: python

     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)

     # automatically auto-loads the best weights from the previous run
-    trainer.test(dataloaders=test_dataloader)
+    trainer.test(dataloaders=test_dataloaders)

     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     trainer.test(model, dataloaders=test_dataloader)

+.. note::
+    `WikiText2` is used in a manner that does not create a train, test, val split. This is done for illustrative purposes only.
+    A proper split can be created in :meth:`lightning.pytorch.core.LightningModule.setup` or :meth:`lightning.pytorch.core.LightningDataModule.setup`.
+
 .. note::

     It is recommended to validate on single device to ensure each sample/batch gets evaluated exactly once.
@@ -403,24 +415,18 @@ By default, the :meth:`~lightning.pytorch.core.module.LightningModule.predict_st
 :meth:`~lightning.pytorch.core.module.LightningModule.forward` method. In order to customize this behaviour,
 simply override the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method.

-For the example let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
+For the example let's override ``predict_step``:

 .. code-block:: python

-    class LitMCdropoutModel(pl.LightningModule):
-        def __init__(self, model, mc_iteration):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
-            self.dropout = nn.Dropout()
-            self.mc_iteration = mc_iteration
-
-        def predict_step(self, batch, batch_idx):
-            # enable Monte Carlo Dropout
-            self.dropout.train()
+            self.model = Transformer(vocab_size=vocab_size)

-            # take average of `self.mc_iteration` iterations
-            pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
-            return pred
+        def predict_step(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)

 Under the hood, Lightning does the following (pseudocode):

@@ -440,15 +446,17 @@ There are two ways to call ``predict()``:
 .. code-block:: python

     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)

     # automatically auto-loads the best weights from the previous run
     predictions = trainer.predict(dataloaders=predict_dataloader)

     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = pl.demos.WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     predictions = trainer.predict(model, dataloaders=test_dataloader)

 Inference in Research
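
A small usage sketch: the value returned by each ``predict_step`` call is collected by the Trainer, so with the class above (and ``dataset``/``dataloader`` assumed from the earlier snippets) one could do, limiting batches for brevity:

.. code-block:: python

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    trainer = pl.Trainer(limit_predict_batches=10, logger=False)
    # each element of `predictions` is what predict_step returned for one batch
    predictions = trainer.predict(model, dataloaders=dataloader)
    print(len(predictions))  # 10 batches, given the limit above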
@@ -460,15 +468,31 @@ If you want to perform inference with the system, you can add a ``forward`` meth

 .. code-block:: python

-    class Autoencoder(pl.LightningModule):
-        def forward(self, x):
-            return self.decoder(x)
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)

+        def forward(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)
+
+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)
+
+
+    model = LightningTransformer(vocab_size=dataset.vocab_size)

-    model = Autoencoder()
     model.eval()
     with torch.no_grad():
-        reconstruction = model(embedding)
+        batch = dataloader.dataset[0]
+        pred = model(batch)

 The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
 such as text generation:
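
As a usage sketch of the forward-based inference above (reusing ``dataset`` and ``dataloader`` from the earlier snippets), a batch can also be pulled straight from the dataloader so the inputs carry a batch dimension:

.. code-block:: python

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    model.eval()

    with torch.no_grad():
        batch = next(iter(dataloader))  # one batched (inputs, target) pair
        log_probs = model(batch)        # forward() delegates to the demo Transformer
        next_tokens = log_probs.argmax(dim=-1)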
@@ -618,7 +642,7 @@ checkpoint, which simplifies model re-instantiation after training.

 .. code-block:: python

-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, layer_1_dim=128, learning_rate=1e-2):
             super().__init__()
             # call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint
@@ -642,7 +666,7 @@ parameters should be provided back when reloading the LightningModule. In this c

 .. code-block:: python

-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, loss_fx, generator_network, layer_1_dim=128):
             super().__init__()
             self.layer_1_dim = layer_1_dim
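
For reference, a minimal sketch of the round trip that ``save_hyperparameters()`` enables; the checkpoint path and layer size are illustrative:

.. code-block:: python

    import torch
    import lightning.pytorch as pl


    class LitMNIST(pl.LightningModule):
        def __init__(self, layer_1_dim=128, learning_rate=1e-2):
            super().__init__()
            # stores layer_1_dim and learning_rate under "hyper_parameters" in the checkpoint
            self.save_hyperparameters()
            self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)


    # later, the saved arguments are restored automatically:
    # model = LitMNIST.load_from_checkpoint("/path/to/checkpoint.ckpt")
    # assert model.hparams.layer_1_dim == 128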
