
Commit 9f140b7

williamFalcon authored and Borda committed
updated test (#1073)
1 parent ff1f8ef commit 9f140b7

File tree

2 files changed: +62, -6 lines


docs/source/child_modules.rst

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ Child Modules
 Research projects tend to test different approaches to the same dataset.
 This is very easy to do in Lightning with inheritance.

-For example, imaging we now want to train an Autoencoder to use as a feature extractor for MNIST images.
+For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
 Recall that `LitMNIST` already defines all the dataloading etc... The only things
 that change in the `Autoencoder` model are the init, forward, training, validation and test step.

docs/source/introduction_guide.rst

Lines changed: 61 additions & 5 deletions
@@ -212,6 +212,50 @@ Notice the code is exactly the same, except now the training dataloading has bee
 under the `train_dataloader` method. This is great because if you run into a project that uses Lightning and want
 to figure out how they prepare their training data you can just look in the `train_dataloader` method.

+Usually though, we want to separate the things that write to disk in data-processing from
+things like transforms which happen in memory.
+
+.. code-block:: python
+
+    class LitMNIST(pl.LightningModule):
+
+        def prepare_data(self):
+            # download only
+            MNIST(os.getcwd(), train=True, download=True)
+
+        def train_dataloader(self):
+            # no download, just transform
+            transform=transforms.Compose([transforms.ToTensor(),
+                                          transforms.Normalize((0.1307,), (0.3081,))])
+            mnist_train = MNIST(os.getcwd(), train=True, download=False,
+                                transform=transform)
+            return DataLoader(mnist_train, batch_size=64)
+
+Doing it in the `prepare_data` method ensures that when you have
+multiple GPUs you won't overwrite the data. This is a contrived example
+but it gets more complicated with things like NLP or Imagenet.
+
+In general fill these methods with the following:
+
+.. code-block:: python
+
+    class LitMNIST(pl.LightningModule):
+
+        def prepare_data(self):
+            # stuff here is done once at the very beginning of training
+            # before any distributed training starts
+
+            # download stuff
+            # save to disk
+            # etc...
+
+        def train_dataloader(self):
+            # data transforms
+            # dataset creation
+            # return a DataLoader
+
+
 Optimizer
 ^^^^^^^^^

@@ -606,11 +650,11 @@ metrics we care about, generate samples or add more to our logs.
     loss = loss(y_hat, x) # validation_step
     outputs.append({'val_loss': loss}) # validation_step

-    full_loss = outputs.mean() # validation_end
+    full_loss = outputs.mean() # validation_epoch_end

 Since the `validation_step` processes a single batch,
-in Lightning we also have a `validation_end` method which allows you to compute
-statistics on the full dataset and not just the batch.
+in Lightning we also have a `validation_epoch_end` method which allows you to compute
+statistics on the full dataset after an epoch of validation data and not just the batch.

 In addition, we define a `val_dataloader` method which tells the trainer what data to use for validation.
 Notice we split the train split of MNIST into train, validation. We also have to make sure to do the
@@ -640,7 +684,7 @@ sample split in the `train_dataloader` method.
     return mnist_val

 Again, we've just organized the regular PyTorch code into two steps, the `validation_step` method which
-operates on a single batch and the `validation_end` method to compute statistics on all batches.
+operates on a single batch and the `validation_epoch_end` method to compute statistics on all batches.

 If you have these methods defined, Lightning will call them automatically. Now we can train
 while checking the validation set.
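
To make the renamed hook concrete, here is a minimal sketch of how the two validation hooks pair up inside the `LitMNIST` module from earlier in the guide; the cross-entropy loss and the returned dict keys are illustrative assumptions, not lines from this commit.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class LitMNIST(pl.LightningModule):
        # __init__, forward and the dataloader methods are the ones already
        # defined earlier in the guide; only the validation hooks are sketched here.

        def validation_step(self, batch, batch_idx):
            # runs once per validation batch
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            return {'val_loss': loss}

        def validation_epoch_end(self, outputs):
            # `outputs` is the list of dicts returned by every validation_step call,
            # so statistics computed here cover the whole validation epoch
            avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
            return {'val_loss': avg_loss}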
@@ -669,7 +713,7 @@ how it will generalize in the "real world." For this, we use a held-out split of
 Just like the validation loop, we define exactly the same steps for testing:

 - test_step
-- test_end
+- test_epoch_end
 - test_dataloader

 .. code-block:: python
@@ -707,6 +751,17 @@ Once you train your model simply call `.test()`.
     # run test set
     trainer.test()

+.. rst-class:: sphx-glr-script-out
+
+Out:
+
+.. code-block:: none
+
+    --------------------------------------------------------------
+    TEST RESULTS
+    {'test_loss': tensor(1.1703, device='cuda:0')}
+    --------------------------------------------------------------
+
 You can also run the test from a saved lightning model

 .. code-block:: python
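
As a rough sketch of that flow, restoring a trained `LitMNIST` from a checkpoint and handing it to the trainer might look like the following; the checkpoint path is a placeholder and the loading call is an assumption about this release's API, not text from the commit.

.. code-block:: python

    import pytorch_lightning as pl

    # placeholder path; point this at a checkpoint the trainer actually saved
    PATH = 'path/to/checkpoint.ckpt'

    # restore the trained weights and run the test loop on the restored model
    model = LitMNIST.load_from_checkpoint(PATH)
    trainer = pl.Trainer()
    trainer.test(model)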
@@ -881,6 +936,7 @@ you could do your own:
 Every single part of training is configurable this way.
 For a full list look at `lightningModule <lightning-module.rst>`_.

+---------

 Callbacks
 ---------
