Skip to content

Commit 5bb5b14

Browse files
BordaSeanNaren
andcommitted
test PL examples (#4551)
* test PL examples * minor formatting * skip failing * skip failing * args * mnist datamodule * refactor tests * refactor tests * skip * skip * drop DM * drop DM Co-authored-by: Sean Naren <[email protected]>
1 parent 65cad1a commit 5bb5b14

File tree

5 files changed

+274
-70
lines changed

5 files changed

+274
-70
lines changed

.drone.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ steps:
3434
- pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
3535
- pip list
3636
- coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --color=yes --durations=25 # --flake8
37-
- python -m pytest benchmarks pl_examples -v --color=yes --maxfail=2 --durations=0 # --flake8
37+
- python -m pytest benchmarks pl_examples -v --color=yes --maxfail=7 --durations=0 # --flake8
3838
#- cd docs; make doctest; make coverage
3939
- coverage report
4040
# see: https://docs.codecov.io/docs/merging-reports

docs/source/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,8 @@ def package_list_from_file(file):
316316
if SPHINX_MOCK_REQUIREMENTS:
317317
# mock also base packages when we are on RTD since we don't install them there
318318
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt'))
319-
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements/extra.txt'))
320-
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements/loggers.txt'))
319+
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'extra.txt'))
320+
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'loggers.txt'))
321321
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]
322322

323323
autodoc_mock_imports = MOCK_PACKAGES

pl_examples/basic_examples/mnist.py renamed to pl_examples/basic_examples/mnist_classifier.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
from argparse import ArgumentParser
1515

1616
import torch
17-
import pytorch_lightning as pl
1817
from torch.nn import functional as F
1918
from torch.utils.data import DataLoader, random_split
2019

20+
import pytorch_lightning as pl
21+
2122
try:
22-
from torchvision.datasets.mnist import MNIST
2323
from torchvision import transforms
24+
from torchvision.datasets.mnist import MNIST
2425
except Exception as e:
2526
from tests.base.datasets import MNIST
2627

@@ -105,7 +106,7 @@ def cli_main():
105106
# ------------
106107
# testing
107108
# ------------
108-
trainer.test(test_dataloaders=test_loader)
109+
result = trainer.test(test_dataloaders=test_loader)
109110

110111

111112
if __name__ == '__main__':
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from abc import ABC
15+
from argparse import ArgumentParser
16+
from random import shuffle
17+
from warnings import warn
18+
19+
import numpy as np
20+
import torch
21+
from torch.nn import functional as F
22+
from torch.utils.data import random_split
23+
24+
import pytorch_lightning as pl
25+
26+
try:
27+
from torchvision import transforms
28+
from torchvision.datasets.mnist import MNIST
29+
except Exception:
30+
from tests.base.datasets import MNIST
31+
32+
try:
33+
import nvidia.dali.ops as ops
34+
import nvidia.dali.types as types
35+
from nvidia.dali.pipeline import Pipeline
36+
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
37+
except (ImportError, ModuleNotFoundError):
38+
warn('NVIDIA DALI is not available')
39+
ops, types, Pipeline, DALIClassificationIterator = ..., ..., ABC, ABC
40+
41+
42+
class ExternalMNISTInputIterator(object):
43+
"""
44+
This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches
45+
"""
46+
47+
def __init__(self, mnist_ds, batch_size):
48+
self.batch_size = batch_size
49+
self.mnist_ds = mnist_ds
50+
self.indices = list(range(len(self.mnist_ds)))
51+
shuffle(self.indices)
52+
53+
def __iter__(self):
54+
self.i = 0
55+
self.n = len(self.mnist_ds)
56+
return self
57+
58+
def __next__(self):
59+
batch = []
60+
labels = []
61+
for _ in range(self.batch_size):
62+
index = self.indices[self.i]
63+
img, label = self.mnist_ds[index]
64+
batch.append(img.numpy())
65+
labels.append(np.array([label], dtype=np.uint8))
66+
self.i = (self.i + 1) % self.n
67+
return (batch, labels)
68+
69+
70+
class ExternalSourcePipeline(Pipeline):
71+
"""
72+
This DALI pipeline class just contains the MNIST iterator
73+
"""
74+
75+
def __init__(self, batch_size, eii, num_threads, device_id):
76+
super(ExternalSourcePipeline, self).__init__(batch_size, num_threads, device_id, seed=12)
77+
self.source = ops.ExternalSource(source=eii, num_outputs=2)
78+
self.build()
79+
80+
def define_graph(self):
81+
images, labels = self.source()
82+
return images, labels
83+
84+
85+
class DALIClassificationLoader(DALIClassificationIterator):
86+
"""
87+
This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it
88+
"""
89+
90+
def __init__(
91+
self,
92+
pipelines,
93+
size=-1,
94+
reader_name=None,
95+
auto_reset=False,
96+
fill_last_batch=True,
97+
dynamic_shape=False,
98+
last_batch_padded=False,
99+
):
100+
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded)
101+
102+
def __len__(self):
103+
batch_count = self._size // (self._num_gpus * self.batch_size)
104+
last_batch = 1 if self._fill_last_batch else 0
105+
return batch_count + last_batch
106+
107+
108+
class LitClassifier(pl.LightningModule):
109+
def __init__(self, hidden_dim=128, learning_rate=1e-3):
110+
super().__init__()
111+
self.save_hyperparameters()
112+
113+
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
114+
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
115+
116+
def forward(self, x):
117+
x = x.view(x.size(0), -1)
118+
x = torch.relu(self.l1(x))
119+
x = torch.relu(self.l2(x))
120+
return x
121+
122+
def split_batch(self, batch):
123+
return batch[0]["data"], batch[0]["label"].squeeze().long()
124+
125+
def training_step(self, batch, batch_idx):
126+
x, y = self.split_batch(batch)
127+
y_hat = self(x)
128+
loss = F.cross_entropy(y_hat, y)
129+
return loss
130+
131+
def validation_step(self, batch, batch_idx):
132+
x, y = self.split_batch(batch)
133+
y_hat = self(x)
134+
loss = F.cross_entropy(y_hat, y)
135+
self.log('valid_loss', loss)
136+
137+
def test_step(self, batch, batch_idx):
138+
x, y = self.split_batch(batch)
139+
y_hat = self(x)
140+
loss = F.cross_entropy(y_hat, y)
141+
self.log('test_loss', loss)
142+
143+
def configure_optimizers(self):
144+
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
145+
146+
@staticmethod
147+
def add_model_specific_args(parent_parser):
148+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
149+
parser.add_argument('--hidden_dim', type=int, default=128)
150+
parser.add_argument('--learning_rate', type=float, default=0.0001)
151+
return parser
152+
153+
154+
def cli_main():
155+
pl.seed_everything(1234)
156+
157+
# ------------
158+
# args
159+
# ------------
160+
parser = ArgumentParser()
161+
parser.add_argument('--batch_size', default=32, type=int)
162+
parser = pl.Trainer.add_argparse_args(parser)
163+
parser = LitClassifier.add_model_specific_args(parser)
164+
args = parser.parse_args()
165+
166+
# ------------
167+
# data
168+
# ------------
169+
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
170+
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
171+
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
172+
173+
eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size)
174+
eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size)
175+
eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)
176+
177+
pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
178+
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)
179+
180+
pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
181+
val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)
182+
183+
pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0)
184+
test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False)
185+
186+
# ------------
187+
# model
188+
# ------------
189+
model = LitClassifier(args.hidden_dim, args.learning_rate)
190+
191+
# ------------
192+
# training
193+
# ------------
194+
trainer = pl.Trainer.from_argparse_args(args)
195+
trainer.fit(model, train_loader, val_loader)
196+
197+
# ------------
198+
# testing
199+
# ------------
200+
trainer.test(test_dataloaders=test_loader)
201+
202+
203+
if __name__ == "__main__":
204+
cli_main()

0 commit comments

Comments
 (0)