"""
Simple PyTorch Lightning example: a small CNN trained on MNIST.
"""

# Imports
import torch
import torch.nn.functional as F  # Parameterless functions, like (some) activation functions
import torchvision.datasets as datasets  # Standard datasets
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
from torch import optim  # For optimizers like SGD, Adam, etc.
from torch import nn  # All neural network modules
from torch.utils.data import DataLoader  # Easier dataset management by creating mini-batches etc.
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.callbacks import Callback, EarlyStopping


# Allow faster, reduced-precision float32 matmul kernels on supported hardware
precision = "medium"
torch.set_float32_matmul_precision(precision)
criterion = nn.CrossEntropyLoss()

class CNNLightning(pl.LightningModule):
    def __init__(self, lr=3e-4, in_channels=1, num_classes=10):
        super().__init__()
        self.lr = lr
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Two 2x2 poolings shrink 28x28 MNIST inputs to 16 feature maps of 7x7
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = criterion(y_hat, y)
        self.train_acc(y_hat, y)  # update accuracy state for this batch
        self.log(
            "train_acc_step",
            self.train_acc,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )
        return loss

    def training_epoch_end(self, outputs):
        # Reset accumulated training accuracy between epochs
        self.train_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = F.cross_entropy(y_hat, y)
        accuracy = self.test_acc(y_hat, y)
        self.log("test_loss", loss, on_step=True)
        self.log("test_acc", accuracy, on_step=True)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = F.cross_entropy(y_hat, y)
        accuracy = self.val_acc(y_hat, y)
        # Log at epoch level (the default here) so EarlyStopping can monitor "val_loss"
        self.log("val_loss", loss)
        self.log("val_acc", accuracy)

    def predict_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        return y_hat

    def _common_step(self, x, batch_idx):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)  # flatten to (batch, 16 * 7 * 7)
        y_hat = self.fc1(x)
        return y_hat

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


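# A sketch of how configure_optimizers could also return a learning-rate
# scheduler (the dict form is standard Lightning API; the StepLR settings
# below are illustrative, not tuned):
#
# def configure_optimizers(self):
#     optimizer = optim.Adam(self.parameters(), lr=self.lr)
#     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
#     return {"optimizer": optimizer, "lr_scheduler": scheduler}
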
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=512):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # Download once, on a single process, before setup() runs on each device
        datasets.MNIST(root="dataset/", train=True, download=True)
        datasets.MNIST(root="dataset/", train=False, download=True)

    def setup(self, stage):
        mnist_full = datasets.MNIST(
            root="dataset/", train=True, transform=transforms.ToTensor(), download=True
        )
        self.mnist_test = datasets.MNIST(
            root="dataset/", train=False, transform=transforms.ToTensor(), download=True
        )
        # Hold out 5,000 of the 60,000 training images for validation
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(
            mnist_full, [55000, 5000]
        )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=6,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=2, shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=2, shuffle=False
        )


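# Quick standalone sanity check of the DataModule (a sketch; stage="fit"
# mirrors what the Trainer passes while fitting):
#
# dm = MNISTDataModule(batch_size=64)
# dm.prepare_data()
# dm.setup(stage="fit")
# x, y = next(iter(dm.train_dataloader()))
# print(x.shape)  # expected: torch.Size([64, 1, 28, 28])
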
class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")
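
# Callbacks receive both the Trainer and the LightningModule, so they can
# inspect logged state. A sketch (trainer.callback_metrics is standard
# Lightning API; "val_loss" is the key logged in validation_step above):
#
# class ValLossPrinter(Callback):
#     def on_validation_epoch_end(self, trainer, pl_module):
#         print("val_loss:", trainer.callback_metrics.get("val_loss"))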


# Lightning moves the model and batches to the right device via the Trainer's
# accelerator/devices arguments, so no manual torch.device handling is needed

if __name__ == "__main__":
    # Initialize network
    model_lightning = CNNLightning()

    trainer = pl.Trainer(
        # fast_dev_run=True,
        # overfit_batches=3,
        max_epochs=5,
        precision=16,  # 16-bit mixed precision training
        accelerator="gpu",
        devices=[0, 1],  # assumes two local GPUs are available
        callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
        auto_lr_find=True,  # used by trainer.tune() below
        enable_model_summary=True,
        profiler="simple",
        strategy="deepspeed_stage_1",  # requires the deepspeed package
        # accumulate_grad_batches=2,
        # auto_scale_batch_size="binsearch",
        # log_every_n_steps=1,
    )

    dm = MNISTDataModule()

    # Run the tuner first; with auto_lr_find=True this searches for a good
    # learning rate (batch size is left as-is unless auto_scale_batch_size
    # is enabled above)
    trainer.tune(model_lightning, dm)

    trainer.fit(
        model=model_lightning,
        datamodule=dm,
    )

    # Test the model on the test loader from the LightningDataModule
    trainer.test(model=model_lightning, datamodule=dm)
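
    # Optional: batched inference using the predict_step defined above.
    # A sketch; the DataModule defines no predict_dataloader, so the test
    # loader is passed directly (trainer.predict is standard Lightning API):
    # preds = trainer.predict(model_lightning, dataloaders=dm.test_dataloader())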