@@ -212,6 +212,50 @@ Notice the code is exactly the same, except now the training dataloading has bee
212212under the `train_dataloader` method. This is great because if you run into a project that uses Lightning and want
213213to figure out how they prepare their training data, you can just look in the `train_dataloader` method.
214214
215+ Usually, though, we want to separate the data-processing steps that write to disk (like downloading) from
216+ things like transforms, which happen in memory.
217+
218+ .. code-block:: python
219+
220+     class LitMNIST(pl.LightningModule):
221+
222+         def prepare_data(self):
223+             # download only
224+             MNIST(os.getcwd(), train=True, download=True)
225+
226+         def train_dataloader(self):
227+             # no download, just transform
228+             transform = transforms.Compose([transforms.ToTensor(),
229+                                             transforms.Normalize((0.1307,), (0.3081,))])
230+             mnist_train = MNIST(os.getcwd(), train=True, download=False,
231+                                 transform=transform)
232+             return DataLoader(mnist_train, batch_size=64)
233+
234+ Doing the download in the `prepare_data` method ensures that when you have
235+ multiple GPUs you won't download or overwrite the data from every process. This is a contrived example,
236+ but the split becomes more important with things like NLP corpora or ImageNet.
237+
238+ In general fill these methods with the following:
239+
240+ .. code-block:: python
241+
242+     class LitMNIST(pl.LightningModule):
243+
244+         def prepare_data(self):
245+             # stuff here is done once at the very beginning of training
246+             # before any distributed training starts
247+
248+             # download stuff
249+             # save to disk
250+             # etc...
251+
252+         def train_dataloader(self):
253+             # data transforms
254+             # dataset creation
255+             # return a DataLoader
256+
257+
258+
215259 Optimizer
216260^^^^^^^^^
217261
@@ -606,11 +650,11 @@ metrics we care about, generate samples or add more to our logs.
606650 loss = loss(y_hat, y) # validation_step
607651 outputs.append({'val_loss': loss}) # validation_step
608652
609- full_loss = outputs.mean() # validation_end
653+ full_loss = outputs.mean() # validation_epoch_end
610654
611655Since the `validation_step` method processes a single batch,
612- in Lightning we also have a `validation_end` method which allows you to compute
613- statistics on the full dataset and not just the batch.
656+ in Lightning we also have a `validation_epoch_end` method which lets you compute
657+ statistics over the full validation set at the end of an epoch, not just over a single batch.
614658
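As a rough sketch (assuming the imports, `forward`, and cross-entropy loss defined in the earlier sections of this guide), the two methods could look like this:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def validation_step(self, batch, batch_idx):
            # runs on a single batch of validation data
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            return {'val_loss': loss}

        def validation_epoch_end(self, outputs):
            # `outputs` is the list of dicts returned by every validation_step call
            avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
            return {'val_loss': avg_loss}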
615659In addition, we define a `val_dataloader` method which tells the trainer what data to use for validation.
616660Notice we split the MNIST train split into train and validation sets. We also have to make sure to do the
@@ -640,7 +684,7 @@ same split in the `train_dataloader` method.
640684 return mnist_val
641685
642686 Again, we've just organized the regular PyTorch code into two steps: the `validation_step` method, which
643- operates on a single batch, and the `validation_end` method to compute statistics on all batches.
687+ operates on a single batch, and the `validation_epoch_end` method, which computes statistics on all batches.
644688
645689If you have these methods defined, Lightning will call them automatically. Now we can train
646690while checking the validation set.
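For completeness, here is a minimal sketch of kicking off a run once these methods exist (assuming `import pytorch_lightning as pl` as in the earlier snippets):

.. code-block:: python

    model = LitMNIST()
    trainer = pl.Trainer()  # the validation loop runs automatically during fit
    trainer.fit(model)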
@@ -669,7 +713,7 @@ how it will generalize in the "real world." For this, we use a held-out split of
669713Just like the validation loop, we define exactly the same steps for testing:
670714
671715- test_step
672- - test_end
716+ - test_epoch_end
673717- test_dataloader
674718
675719.. code-block:: python
@@ -707,6 +751,17 @@ Once you train your model simply call `.test()`.
707751 # run test set
708752 trainer.test()
709753
754+ .. rst-class:: sphx-glr-script-out
755+
756+ Out:
757+
758+ .. code-block:: none
759+
760+ --------------------------------------------------------------
761+ TEST RESULTS
762+ {'test_loss': tensor(1.1703, device='cuda:0')}
763+ --------------------------------------------------------------
764+
710765 You can also run the test from a saved Lightning model.
711766
712767.. code-block:: python
@@ -881,6 +936,7 @@ you could do your own:
881936 Every single part of training is configurable this way.
882937For a full list, look at `LightningModule <lightning-module.rst>`_.
883938
939+ ---------
884940
885941Callbacks
886942---------