
Commit 7cc4cc5

Borda authored and lexierule committed
Legacy: simple classif training (#8535)
* simple_classif_training
* fix test
* pt1.6
* automate

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas chaton <[email protected]>

(cherry picked from commit 0778ffb)
1 parent d5acc0c commit 7cc4cc5

File tree

8 files changed (+320 additions, -163 deletions)

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
name: Create Legacy Ckpts

# https://help.github.com/en/actions/reference/events-that-trigger-workflows
on:
  workflow_dispatch:

jobs:
  create-legacy-ckpts:
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2

      - uses: actions/setup-python@v2
        with:
          python-version: 3.8

      - name: Install dependencies
        run: |
          pip install -r requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
          pip install awscli

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_KEY_ID }}
          aws-region: us-east-1

      - name: Generate checkpoint
        run: |
          while IFS= read -r line; do
            bash legacy/generate_checkpoints.sh $line
          done <<< $(cat legacy/back-compatible-versions.txt)

      - name: Push files to S3
        working-directory: ./legacy
        run: |
          aws s3 sync legacy/checkpoints/ s3://pl-public-data/legacy/checkpoints/
          zip -r checkpoints.zip checkpoints
          aws s3 cp checkpoints.zip s3://pl-public-data/legacy/ --acl public-read
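
For local debugging, the "Generate checkpoint" step above can be replicated outside of Actions. The following is a rough Python sketch, not part of this commit, assuming it is run from the repository root:

```python
# Hypothetical local equivalent of the "Generate checkpoint" workflow step:
# run generate_checkpoints.sh once per version listed in
# legacy/back-compatible-versions.txt.
import subprocess
from pathlib import Path

versions = Path("legacy/back-compatible-versions.txt").read_text().split()
for ver in versions:
    subprocess.run(["bash", "legacy/generate_checkpoints.sh", ver], check=True)
```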

legacy/README.md

Lines changed: 1 addition & 1 deletion
@@ -14,6 +14,6 @@ unzip -o checkpoints.zip
 To back-populate the collection with past versions, you can use the following bash:

 ```bash
-bash generate_checkpoints.sh 1.0.2 1.0.3 1.0.4
+bash generate_checkpoints.sh "1.3.7" "1.3.8"
 zip -r checkpoints.zip checkpoints/
 ```
legacy/back-compatible-versions.txt

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
1.0.0
1.0.1
1.0.2
1.0.3
1.0.4
1.0.5
1.0.6
1.0.7
1.0.8
1.1.0
1.1.1
1.1.2
1.1.3
1.1.4
1.1.5
1.1.6
1.1.7
1.1.8
1.2.0
1.2.1
1.2.2
1.2.3
1.2.4
1.2.5
1.2.6
1.2.7
1.2.8
1.2.10
1.3.0
1.3.1
1.3.2
1.3.3
1.3.4
1.3.5
1.3.6
1.3.7
1.3.8
1.4.0
1.4.1

legacy/generate_checkpoints.sh

Lines changed: 7 additions & 5 deletions
@@ -2,12 +2,14 @@
 # Sample call:
 # bash generate_checkpoints.sh 1.0.2 1.0.3 1.0.4

+set -e
+
 LEGACY_PATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
-FROZEN_MIN_PT_VERSION="1.4"
+FROZEN_MIN_PT_VERSION="1.6"

 echo $LEGACY_PATH
 # install some PT version here so it does not need to be reinstalled for each env
-pip install virtualenv "torch==1.5" --quiet --no-cache-dir
+pip install virtualenv "torch==1.6" --quiet

 ENV_PATH="$LEGACY_PATH/vEnv"

@@ -23,14 +25,14 @@ do
   # activate and install the PL version
   source "$ENV_PATH/bin/activate"
   # there are problems loading ckpts in older versions since they were saved by newer versions
-  pip install "pytorch_lightning==$ver" "torch==$FROZEN_MIN_PT_VERSION" --quiet --no-cache-dir
+  pip install "pytorch_lightning==$ver" "torch==$FROZEN_MIN_PT_VERSION" "torchmetrics" "scikit-learn" --quiet

   python --version
   pip --version
   pip list | grep torch

-  python "$LEGACY_PATH/zero_training.py"
-  cp "$LEGACY_PATH/zero_training.py" ${LEGACY_PATH}/checkpoints/${ver}
+  python "$LEGACY_PATH/simple_classif_training.py"
+  cp "$LEGACY_PATH/simple_classif_training.py" ${LEGACY_PATH}/checkpoints/${ver}

   mv ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs/version_0/checkpoints/*.ckpt ${LEGACY_PATH}/checkpoints/${ver}/
   rm -rf ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs
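
After the script runs, each version should end up with its checkpoint and a copy of the training script directly under `legacy/checkpoints/<version>/`. A minimal sanity check of that layout might look like this (illustrative, not part of the commit; assumes the repository root as the working directory):

```python
# Illustrative check (not part of this commit): every generated version folder
# should contain at least one .ckpt file plus the copied training script.
from pathlib import Path

for ver_dir in sorted(Path("legacy/checkpoints").iterdir()):
    assert list(ver_dir.glob("*.ckpt")), f"no checkpoint for {ver_dir.name}"
    assert (ver_dir / "simple_classif_training.py").exists(), ver_dir.name
```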

legacy/simple_classif_training.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
import torch.nn.functional as F
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule, seed_everything
from pytorch_lightning.callbacks import EarlyStopping

PATH_LEGACY = os.path.dirname(__file__)


class SklearnDataset(Dataset):
    """Wraps a pair of sklearn arrays as a torch Dataset with explicit dtypes."""

    def __init__(self, x, y, x_type, y_type):
        self.x = x
        self.y = y
        self._x_type = x_type
        self._y_type = y_type

    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type)

    def __len__(self):
        return len(self.y)


class SklearnDataModule(LightningDataModule):
    """Splits an sklearn dataset into train/valid/test parts and serves dataloaders."""

    def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128):
        super().__init__()
        self.batch_size = batch_size
        self._x, self._y = sklearn_dataset
        self._split_data()
        self._x_type = x_type
        self._y_type = y_type

    def _split_data(self):
        # 80/20 train/test split, then a 60/40 train/valid split of the train part
        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
            self._x, self._y, test_size=0.20, random_state=42
        )
        self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
            self.x_train, self.y_train, test_size=0.40, random_state=42
        )

    def train_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type),
            shuffle=True,
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size
        )

    def test_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size
        )


class ClassifDataModule(SklearnDataModule):
    def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
        data = make_classification(
            n_samples=length,
            n_features=num_features,
            n_classes=num_classes,
            n_clusters_per_class=2,
            n_informative=int(num_features / num_classes),
            random_state=42,
        )
        super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size)


class ClassificationModel(LightningModule):
    def __init__(self, num_features=24, num_classes=3, lr=0.01):
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        # three Linear+ReLU blocks followed by a classification head
        for i in range(3):
            setattr(self, f"layer_{i}", nn.Linear(num_features, num_features))
            setattr(self, f"layer_{i}a", torch.nn.ReLU())
        setattr(self, "layer_end", nn.Linear(num_features, num_classes))

        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, x):
        x = self.layer_0(x)
        x = self.layer_0a(x)
        x = self.layer_1(x)
        x = self.layer_1a(x)
        x = self.layer_2(x)
        x = self.layer_2a(x)
        x = self.layer_end(x)
        logits = F.softmax(x, dim=1)
        return logits

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return [optimizer], []

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc(logits, y), prog_bar=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False)
        self.log("val_acc", self.valid_acc(logits, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False)
        self.log("test_acc", self.test_acc(logits, y), prog_bar=True)


def main_train(dir_path, max_epochs: int = 20):
    seed_everything(42)
    stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
    trainer = pl.Trainer(
        default_root_dir=dir_path,
        gpus=int(torch.cuda.is_available()),
        precision=(16 if torch.cuda.is_available() else 32),
        checkpoint_callback=True,
        callbacks=[stopping],
        min_epochs=3,
        max_epochs=max_epochs,
        accumulate_grad_batches=2,
        deterministic=True,
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer.fit(model, datamodule=dm)
    res = trainer.test(model, datamodule=dm)
    assert res[0]["test_loss"] <= 0.7
    assert res[0]["test_acc"] >= 0.85
    assert trainer.current_epoch < (max_epochs - 1)


if __name__ == "__main__":
    path_dir = os.path.join(PATH_LEGACY, "checkpoints", str(pl.__version__))
    main_train(path_dir)
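
The checkpoints produced by this script are meant to be loaded by newer Lightning versions in backward-compatibility tests. A minimal sketch of such a round-trip, assuming it is run from the `legacy/` directory and that a generated folder such as `checkpoints/1.3.8/` exists (the version and glob pattern here are assumptions, not part of the commit):

```python
# Minimal round-trip sketch (an assumption, not part of this commit): load a
# generated legacy checkpoint with a current Lightning install and re-test it.
import glob
import os

import pytorch_lightning as pl
from simple_classif_training import ClassifDataModule, ClassificationModel

ckpt_path = glob.glob(os.path.join("checkpoints", "1.3.8", "*.ckpt"))[0]
model = ClassificationModel.load_from_checkpoint(ckpt_path)
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, datamodule=ClassifDataModule())
```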

legacy/zero_training.py

Lines changed: 0 additions & 88 deletions
This file was deleted.
