@@ -278,6 +278,9 @@ Doing it in the `prepare_data` method ensures that when you have
 multiple GPUs you won't overwrite the data. This is a contrived example
 but it gets more complicated with things like NLP or Imagenet.

+`prepare_data` gets called on the `LOCAL_RANK=0` GPU per node. If your nodes share a file system,
+set `Trainer(prepare_data_per_node=False)` and it will be called from node=0, gpu=0 only.
+
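+As a minimal sketch, assuming `model` is any LightningModule instance (the name is hypothetical here):
+
+.. code-block:: python
+
+    # nodes mount a shared filesystem, so the data only needs to be prepared once;
+    # prepare_data will now run on node=0, gpu=0 only
+    trainer = Trainer(prepare_data_per_node=False)
+    trainer.fit(model)
+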
 In general fill these methods with the following:

 .. testcode::
@@ -535,16 +538,21 @@ will cause all sorts of issues.
 To solve this problem, move the download code to the `prepare_data` method in the LightningModule.
 In this method we do all the preparation we need to do once (instead of on every GPU).

+`prepare_data` can be called in two ways: once per node, or only on the root node (`Trainer(prepare_data_per_node=False)`).
+
 .. testcode::

     class LitMNIST(LightningModule):
         def prepare_data(self):
+            # download only
+            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
+            MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
+
+        def setup(self, stage):
             # transform
             transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-
-            # download
-            mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
-            mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
+            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
+            mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

             # train/val split
             mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
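+
+Roughly, the Trainer invokes these hooks in the following order (a sketch of the call sequence, assuming `model` is an instance of `LitMNIST`; not the actual Trainer source):
+
+.. code-block:: python
+
+    # prepare_data: called once per node (or only on node 0 when
+    # prepare_data_per_node=False) -- the safe place to download
+    model.prepare_data()
+
+    # setup: called in every process/on every GPU -- the safe place to
+    # build datasets, apply transforms, and create splits
+    model.setup(stage='fit')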