Skip to content

Commit 79e1426

Browse files
Docs clean-up (#2234)
* update docs * update docs * update docs * update docs * update docs * update docs
1 parent a2d3ee8 commit 79e1426

File tree

8 files changed

+30
-18
lines changed

8 files changed

+30
-18
lines changed

docs/source/introduction_guide.rst

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ Doing it in the `prepare_data` method ensures that when you have
278278
multiple GPUs you won't overwrite the data. This is a contrived example
279279
but it gets more complicated with things like NLP or Imagenet.
280280

281+
`prepare_data` gets called on the `LOCAL_RANK=0` GPU per node. If your nodes share a file system,
282+
set `Trainer(prepare_data_per_node=False)` and it will be code from node=0, gpu=0 only.
283+
281284
In general fill these methods with the following:
282285

283286
.. testcode::
@@ -535,16 +538,21 @@ will cause all sorts of issues.
535538
To solve this problem, move the download code to the `prepare_data` method in the LightningModule.
536539
In this method we do all the preparation we need to do once (instead of on every gpu).
537540

541+
`prepare_data` can be called in two ways, once per node or only on the root node (`Trainer(prepare_data_per_node=False)`).
542+
538543
.. testcode::
539544

540545
class LitMNIST(LightningModule):
541546
def prepare_data(self):
547+
# download only
548+
MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
549+
MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
550+
551+
def setup(self, stage):
542552
# transform
543553
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
544-
545-
# download
546-
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
547-
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
554+
MNIST(os.getcwd(), train=True, download=False, transform=transform)
555+
MNIST(os.getcwd(), train=False, download=False, transform=transform)
548556

549557
# train/val split
550558
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,11 @@ def configure_optimizers(self):
307307

308308
def prepare_data(self):
309309
"""Download images and prepare images datasets."""
310-
311-
# 1. Download the images
312310
download_and_extract_archive(url=DATA_URL,
313311
download_root=self.dl_path,
314312
remove_finished=True)
315313

314+
def setup(self, stage: str):
316315
data_path = Path(self.dl_path).joinpath('cats_and_dogs_filtered')
317316

318317
# 2. Load the data + preprocessing & data augmentation

pl_examples/models/lightning_template.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,13 @@ def configure_optimizers(self):
141141
return [optimizer], [scheduler]
142142

143143
def prepare_data(self):
144-
transform = transforms.Compose([transforms.ToTensor(),
145-
transforms.Normalize((0.5,), (1.0,))])
146-
self.mnist_train = MNIST(self.data_root, train=True, download=True, transform=transform)
147-
self.mnist_test = MNIST(self.data_root, train=False, download=True, transform=transform)
144+
MNIST(self.data_root, train=True, download=True, transform=transforms.ToTensor())
145+
MNIST(self.data_root, train=False, download=True, transform=transforms.ToTensor())
146+
147+
def setup(self, stage):
148+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
149+
self.mnist_train = MNIST(self.data_root, train=True, download=False, transform=transform)
150+
self.mnist_test = MNIST(self.data_root, train=False, download=False, transform=transform)
148151

149152
def train_dataloader(self):
150153
log.info('Training data loader called.')

pytorch_lightning/core/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,12 @@ def training_step(self, batch, batch_idx):
267267
>>> class LitModel(pl.LightningModule):
268268
... def prepare_data(self):
269269
... # download
270-
... mnist_train = MNIST(os.getcwd(), train=True, download=True,
271-
... transform=transforms.ToTensor())
272-
... mnist_test = MNIST(os.getcwd(), train=False, download=True,
273-
... transform=transforms.ToTensor())
270+
... MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
271+
... MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
274272
...
273+
... def setup(self, stage):
274+
... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
275+
... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
275276
... # train/val split
276277
... mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
277278
...

pytorch_lightning/core/hooks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def setup(self, stage: str):
2222
Called at the beginning of fit and test.
2323
2424
Args:
25-
step: either 'fit' or 'test'
25+
stage: either 'fit' or 'test'
2626
"""
2727

2828
def teardown(self, stage: str):

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,8 @@ def tbptt_split_batch(self, batch, split_size):
12901290
def prepare_data(self) -> None:
12911291
"""
12921292
Use this to download and prepare data.
1293-
In distributed (GPU, TPU), this will only be called once.
1293+
In distributed (GPU, TPU), this will only be called once on the local_rank=0 of each node.
1294+
To call this on only the root=0 of the main node, use `Trainer(prepare_data_per_node=False)`
12941295
This is called before requesting the dataloaders:
12951296
12961297
.. code-block:: python

tests/base/model_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def loss(self, labels, logits):
104104
return nll
105105

106106
def prepare_data(self):
107-
_ = TrialMNIST(root=self.data_root, train=True, download=True)
107+
TrialMNIST(root=self.data_root, train=True, download=True)
108108

109109
@staticmethod
110110
def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> dict:

tests/callbacks/test_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(self):
6969
self.on_test_start_called = False
7070
self.on_test_end_called = False
7171

72-
def setup(self, trainer, step: str):
72+
def setup(self, trainer, stage: str):
7373
assert isinstance(trainer, Trainer)
7474
self.setup_called = True
7575

0 commit comments

Comments
 (0)