# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
import torch.nn.functional as F
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule, seed_everything
from pytorch_lightning.callbacks import EarlyStopping

# Base directory of this script; generated checkpoints are written beneath it.
PATH_LEGACY = os.path.dirname(__file__)


class SklearnDataset(Dataset):
    """Minimal ``Dataset`` that wraps the numpy arrays from scikit-learn and casts them to the requested dtypes."""

    def __init__(self, x, y, x_type, y_type):
        self.x = x
        self.y = y
        self._x_type = x_type
        self._y_type = y_type

    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type)

    def __len__(self):
        return len(self.y)
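
# Illustrative usage (a sketch, not executed by this script; the names below are hypothetical):
#
#   x, y = make_classification(n_samples=100, n_features=24, random_state=42)
#   ds = SklearnDataset(x, y, x_type=torch.float32, y_type=torch.long)
#   xb, yb = ds[0]  # float32 tensor of shape (24,) and a scalar long tensor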


class SklearnDataModule(LightningDataModule):
    """Splits an ``(x, y)`` dataset pair into train/val/test sets and serves them as dataloaders."""

    def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128):
        super().__init__()
        self.batch_size = batch_size
        self._x, self._y = sklearn_dataset
        self._split_data()
        self._x_type = x_type
        self._y_type = y_type

    def _split_data(self):
        # Hold out 20% for test, then 40% of the remainder for validation:
        # 48% train / 32% val / 20% test overall.
        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
            self._x, self._y, test_size=0.20, random_state=42
        )
        self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
            self.x_train, self.y_train, test_size=0.40, random_state=42
        )

    def train_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type),
            shuffle=True,
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size
        )

    def test_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size
        )


class ClassifDataModule(SklearnDataModule):
    """Synthetic multi-class classification data built with ``sklearn.datasets.make_classification``."""

    def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
        data = make_classification(
            n_samples=length,
            n_features=num_features,
            n_classes=num_classes,
            n_clusters_per_class=2,
            n_informative=int(num_features / num_classes),
            random_state=42,
        )
        super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size)
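
# Illustrative datamodule usage (a sketch; not executed as part of this script):
#
#   dm = ClassifDataModule()
#   xb, yb = next(iter(dm.train_dataloader()))  # xb: (128, 24) float32, yb: (128,) long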


class ClassificationModel(LightningModule):
    """Small fully-connected classifier whose checkpoints serve as the legacy test fixtures."""

    def __init__(self, num_features=24, num_classes=3, lr=0.01):
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        # Three hidden Linear + ReLU blocks registered under explicit attribute
        # names (layer_0, layer_0a, ...), followed by a projection to the classes.
        for i in range(3):
            setattr(self, f"layer_{i}", nn.Linear(num_features, num_features))
            setattr(self, f"layer_{i}a", torch.nn.ReLU())
        self.layer_end = nn.Linear(num_features, num_classes)

        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, x):
        x = self.layer_0(x)
        x = self.layer_0a(x)
        x = self.layer_1(x)
        x = self.layer_1a(x)
        x = self.layer_2(x)
        x = self.layer_2a(x)
        # Return raw logits: ``F.cross_entropy`` applies log-softmax internally,
        # so running softmax here first would effectively apply it twice.
        return self.layer_end(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # One optimizer, no LR schedulers.
        return [optimizer], []

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc(logits, y), prog_bar=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False)
        self.log("val_acc", self.valid_acc(logits, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False)
        self.log("test_acc", self.test_acc(logits, y), prog_bar=True)
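
# Shape sanity check for the model (a sketch; not executed by this script):
#
#   model = ClassificationModel()
#   out = model(torch.randn(8, 24))  # raw logits of shape (8, 3)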


def main_train(dir_path, max_epochs: int = 20):
    seed_everything(42)
    # Stop early once validation accuracy improves by less than 0.005 per check.
    stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
    trainer = pl.Trainer(
        default_root_dir=dir_path,
        gpus=int(torch.cuda.is_available()),
        precision=(16 if torch.cuda.is_available() else 32),
        checkpoint_callback=True,
        callbacks=[stopping],
        min_epochs=3,
        max_epochs=max_epochs,
        accumulate_grad_batches=2,
        deterministic=True,
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer.fit(model, datamodule=dm)
    res = trainer.test(model, datamodule=dm)
    assert res[0]["test_loss"] <= 0.7
    assert res[0]["test_acc"] >= 0.85
    # Early stopping should have fired before the epoch budget ran out.
    assert trainer.current_epoch < (max_epochs - 1)


if __name__ == "__main__":
    # Write checkpoints under ``checkpoints/<current pytorch_lightning version>``.
    path_dir = os.path.join(PATH_LEGACY, "checkpoints", str(pl.__version__))
    main_train(path_dir)
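
# Restoring one of the generated checkpoints later (illustrative path):
#
#   model = ClassificationModel.load_from_checkpoint("checkpoints/<version>/.../epoch=X.ckpt")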