@@ -85,38 +85,42 @@ Here are the only required methods.
 .. code-block:: python

     import lightning.pytorch as pl
-    import torch.nn as nn
-    import torch.nn.functional as F
+    import torch

+    from lightning.pytorch.demos import Transformer

-    class LitModel(pl.LightningModule):
-        def __init__(self):
+
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.l1 = nn.Linear(28 * 28, 10)
+            self.model = Transformer(vocab_size=vocab_size)

-        def forward(self, x):
-            return torch.relu(self.l1(x.view(x.size(0), -1)))
+        def forward(self, inputs, target):
+            return self.model(inputs, target)

         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss

         def configure_optimizers(self):
-            return torch.optim.Adam(self.parameters(), lr=0.02)
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)

 Which you can train by doing:

 .. code-block:: python

-    train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
-    trainer = pl.Trainer(max_epochs=1)
-    model = LitModel()
+    from torch.utils.data import DataLoader
+
+    dataset = pl.demos.WikiText2()
+    dataloader = DataLoader(dataset)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)

-    trainer.fit(model, train_dataloaders=train_loader)
+    trainer = pl.Trainer(fast_dev_run=100)
+    trainer.fit(model=model, train_dataloaders=dataloader)

-The LightningModule has many convenience methods, but the core ones you need to know about are:
+The LightningModule has many convenient methods, but the core ones you need to know about are:

 .. list-table::
    :widths: 50 50
@@ -152,15 +156,15 @@ To activate the training loop, override the :meth:`~lightning.pytorch.core.modul

 .. code-block:: python

-    class LitClassifier(pl.LightningModule):
-        def __init__(self, model):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
+            self.model = Transformer(vocab_size=vocab_size)

         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss

 Under the hood, Lightning does the following (pseudocode):
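Roughly, that pseudocode boils down to a plain training loop like the sketch below (simplified; ``optimizer`` and ``train_dataloader`` stand in for the objects the Trainer wires up from ``configure_optimizers`` and your dataloader):

.. code-block:: python

    for batch_idx, batch in enumerate(train_dataloader):
        # forward: training_step computes and returns the loss
        loss = training_step(batch, batch_idx)

        # clear gradients
        optimizer.zero_grad()

        # backward
        loss.backward()

        # update parameters
        optimizer.step()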
@@ -191,15 +195,15 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~lightning

 .. code-block:: python

-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
+    def training_step(self, batch, batch_idx):
+        inputs, target = batch
+        output = self.model(inputs, target)
+        loss = torch.nn.functional.nll_loss(output, target.view(-1))

-        # logs metrics for each training_step,
-        # and the average across the epoch, to the progress bar and logger
-        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return loss
+        # logs metrics for each training_step,
+        # and the average across the epoch, to the progress bar and logger
+        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        return loss

 The :meth:`~lightning.pytorch.core.module.LightningModule.log` method automatically reduces the
 requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:
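In rough terms, the reduction amounts to the sketch below (simplified; ``outs`` stands in for Lightning's internal accumulation, and the mean is only the default reduction):

.. code-block:: python

    outs = []
    for batch_idx, batch in enumerate(train_dataloader):
        # forward
        loss = training_step(batch, batch_idx)
        outs.append(loss.detach())

        # clear gradients, backward, update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # on_epoch=True: reduce the per-step values over the whole epoch (mean by default)
    epoch_metric = torch.mean(torch.stack(outs))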
@@ -230,25 +234,25 @@ override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` metho

 .. code-block:: python

-    def __init__(self):
-        super().__init__()
-        self.training_step_outputs = []
-
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        preds = ...
-        self.training_step_outputs.append(preds)
-        return loss
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.training_step_outputs = []

+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            preds = ...
+            self.training_step_outputs.append(preds)
+            return loss

-    def on_train_epoch_end(self):
-        all_preds = torch.stack(self.training_step_outputs)
-        # do something with all preds
-        ...
-        self.training_step_outputs.clear()  # free memory
+        def on_train_epoch_end(self):
+            all_preds = torch.stack(self.training_step_outputs)
+            # do something with all preds
+            ...
+            self.training_step_outputs.clear()  # free memory


 ------------------
@@ -264,10 +268,10 @@ To activate the validation loop while training, override the :meth:`~lightning.p

 .. code-block:: python

-    class LitModel(pl.LightningModule):
+    class LightningTransformer(pl.LightningModule):
         def validation_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
+            inputs, target = batch
+            output = self.model(inputs, target)
-            loss = F.cross_entropy(y_hat, y)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             self.log("val_loss", loss)
 
@@ -300,8 +304,8 @@ and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.

 .. code-block:: python

-    model = Model()
-    trainer = Trainer()
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    trainer = pl.Trainer()
     trainer.validate(model)

 .. note::
@@ -322,25 +326,25 @@ Note that this method is called before :meth:`~lightning.pytorch.LightningModule

 .. code-block:: python

-    def __init__(self):
-        super().__init__()
-        self.validation_step_outputs = []
-
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        pred = ...
-        self.validation_step_outputs.append(pred)
-        return pred
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.validation_step_outputs = []

+        def validation_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            pred = ...
+            self.validation_step_outputs.append(pred)
+            return pred

-    def on_validation_epoch_end(self):
-        all_preds = torch.stack(self.validation_step_outputs)
-        # do something with all preds
-        ...
-        self.validation_step_outputs.clear()  # free memory
+        def on_validation_epoch_end(self):
+            all_preds = torch.stack(self.validation_step_outputs)
+            # do something with all preds
+            ...
+            self.validation_step_outputs.clear()  # free memory

 ----------------
 
@@ -358,9 +363,10 @@ The only difference is that the test loop is only called when :meth:`~lightning.

 .. code-block:: python

-    model = Model()
-    trainer = Trainer()
-    trainer.fit(model)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)

     # automatically loads the best weights for you
     trainer.test(model)
@@ -370,17 +376,23 @@ There are two ways to call ``test()``:

 .. code-block:: python
     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)
 
     # automatically loads the best weights from the previous run
-    trainer.test(dataloaders=test_dataloader)
+    trainer.test(dataloaders=test_dataloaders)

     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     trainer.test(model, dataloaders=test_dataloader)
 
+.. note::
+    ``WikiText2`` is used here without creating a train/val/test split; this is done for illustrative purposes only.
+    A proper split can be created in :meth:`lightning.pytorch.core.LightningModule.setup` or :meth:`lightning.pytorch.core.LightningDataModule.setup` (see the sketch below).
+
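For instance, a minimal sketch of such a split inside a ``LightningDataModule`` (the class name and the 80/10/10 ratio below are illustrative only):

.. code-block:: python

    import lightning.pytorch as pl
    from torch.utils.data import DataLoader, random_split
    from lightning.pytorch.demos import WikiText2


    class WikiText2DataModule(pl.LightningDataModule):
        def setup(self, stage=None):
            dataset = WikiText2()
            # illustrative 80/10/10 split
            n_train = int(len(dataset) * 0.8)
            n_val = int(len(dataset) * 0.1)
            splits = [n_train, n_val, len(dataset) - n_train - n_val]
            self.train_set, self.val_set, self.test_set = random_split(dataset, splits)

        def train_dataloader(self):
            return DataLoader(self.train_set)

        def val_dataloader(self):
            return DataLoader(self.val_set)

        def test_dataloader(self):
            return DataLoader(self.test_set)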
 .. note::

     It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once.
@@ -403,24 +415,18 @@ By default, the :meth:`~lightning.pytorch.core.module.LightningModule.predict_st
 :meth:`~lightning.pytorch.core.module.LightningModule.forward` method. In order to customize this behaviour,
 simply override the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method.

-For the example let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
+For this example, let's override ``predict_step``:

 .. code-block:: python

-    class LitMCdropoutModel(pl.LightningModule):
-        def __init__(self, model, mc_iteration):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
-            self.dropout = nn.Dropout()
-            self.mc_iteration = mc_iteration
-
-        def predict_step(self, batch, batch_idx):
-            # enable Monte Carlo Dropout
-            self.dropout.train()
+            self.model = Transformer(vocab_size=vocab_size)

-            # take average of `self.mc_iteration` iterations
-            pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
-            return pred
+        def predict_step(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)

 Under the hood, Lightning does the following (pseudocode):
 
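Roughly, that amounts to the sketch below (simplified; the model is put into eval mode and gradients are disabled):

.. code-block:: python

    model.eval()
    all_preds = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(predict_dataloader):
            pred = model.predict_step(batch)
            all_preds.append(pred)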
@@ -440,15 +446,17 @@ There are two ways to call ``predict()``:

 .. code-block:: python
     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)
 
     # automatically loads the best weights from the previous run
     predictions = trainer.predict(dataloaders=predict_dataloader)

     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = pl.demos.WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     predictions = trainer.predict(model, dataloaders=test_dataloader)

 Inference in Research
@@ -460,15 +468,31 @@ If you want to perform inference with the system, you can add a ``forward`` meth

 .. code-block:: python

-    class Autoencoder(pl.LightningModule):
-        def forward(self, x):
-            return self.decoder(x)
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)

+        def forward(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)
+
+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)
+
+
+    model = LightningTransformer(vocab_size=dataset.vocab_size)

-    model = Autoencoder()
     model.eval()
     with torch.no_grad():
-        reconstruction = model(embedding)
+        batch = dataloader.dataset[0]
+        pred = model(batch)

 The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
 such as text generation:
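As a rough illustration only (this class, its submodules, and the sizes are made-up stand-ins rather than the example from the docs), a generation-style ``forward`` might decode greedily, one token at a time:

.. code-block:: python

    import torch
    import torch.nn as nn
    import lightning.pytorch as pl


    class TinyGenerator(pl.LightningModule):
        def __init__(self, vocab_size=1000, d_model=64):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, d_model)
            self.rnn = nn.GRU(d_model, d_model, batch_first=True)
            self.head = nn.Linear(d_model, vocab_size)

        def forward(self, prompt, max_new_tokens=20):
            # prompt: a (1, T) tensor of token ids; greedily append one token per step
            tokens = prompt
            for _ in range(max_new_tokens):
                hidden, _ = self.rnn(self.embed(tokens))
                next_token = self.head(hidden[:, -1]).argmax(dim=-1, keepdim=True)
                tokens = torch.cat([tokens, next_token], dim=1)
            return tokens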
@@ -618,7 +642,7 @@ checkpoint, which simplifies model re-instantiation after training.

 .. code-block:: python

-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, layer_1_dim=128, learning_rate=1e-2):
             super().__init__()
             # call this to save (layer_1_dim=128, learning_rate=1e-2) to the checkpoint
@@ -642,7 +666,7 @@ parameters should be provided back when reloading the LightningModule. In this c

 .. code-block:: python

-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, loss_fx, generator_network, layer_1_dim=128):
             super().__init__()
             self.layer_1_dim = layer_1_dim
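Assuming ``loss_fx`` and ``generator_network`` are the arguments excluded from the checkpoint here, they have to be passed back in through ``load_from_checkpoint``; a minimal sketch (``MyGenerator`` is a hypothetical stand-in for whatever generator network was used at training time):

.. code-block:: python

    # loss_fx and generator_network were not saved to the checkpoint, so provide them again
    model = LitMNIST.load_from_checkpoint(
        PATH,
        loss_fx=torch.nn.functional.cross_entropy,
        generator_network=MyGenerator(),  # hypothetical network class
    )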