Skip to content

Commit 94f6c02

Browse files
add lightning code, finetuning whisper, recommender system neural collaborative filtering
1 parent c646ef6 commit 94f6c02

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+17977
-25
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
Simple pytorch lightning example
3+
"""
4+
5+
# Imports
6+
import torch
7+
import torch.nn.functional as F # Parameterless functions, like (some) activation functions
8+
import torchvision.datasets as datasets # Standard datasets
9+
import torchvision.transforms as transforms # Transformations we can perform on our dataset for augmentation
10+
from torch import optim # For optimizers like SGD, Adam, etc.
11+
from torch import nn # All neural network modules
12+
from torch.utils.data import (
13+
DataLoader,
14+
) # Gives easier dataset managment by creating mini batches etc.
15+
from tqdm import tqdm # For nice progress bar!
16+
import pytorch_lightning as pl
17+
import torchmetrics
18+
from pytorch_lightning.callbacks import Callback, EarlyStopping
19+
20+
21+
precision = "medium"
22+
torch.set_float32_matmul_precision(precision)
23+
criterion = nn.CrossEntropyLoss()
24+
25+
26+
## use 20% of training data for validation
27+
# train_set_size = int(len(train_dataset) * 0.8)
28+
# valid_set_size = len(train_dataset) - train_set_size
29+
#
30+
## split the train set into two
31+
# seed = torch.Generator().manual_seed(42)
32+
# train_dataset, val_dataset = torch.utils.data.random_split(
33+
# train_dataset, [train_set_size, valid_set_size], generator=seed
34+
# )
35+
36+
37+
class CNNLightning(pl.LightningModule):
38+
def __init__(self, lr=3e-4, in_channels=1, num_classes=10):
39+
super().__init__()
40+
self.lr = lr
41+
self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
42+
self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
43+
self.conv1 = nn.Conv2d(
44+
in_channels=in_channels,
45+
out_channels=8,
46+
kernel_size=3,
47+
stride=1,
48+
padding=1,
49+
)
50+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
51+
self.conv2 = nn.Conv2d(
52+
in_channels=8,
53+
out_channels=16,
54+
kernel_size=3,
55+
stride=1,
56+
padding=1,
57+
)
58+
self.fc1 = nn.Linear(16 * 7 * 7, num_classes)
59+
self.lr = lr
60+
61+
def training_step(self, batch, batch_idx):
62+
x, y = batch
63+
y_hat = self._common_step(x, batch_idx)
64+
loss = criterion(y_hat, y)
65+
accuracy = self.train_acc(y_hat, y)
66+
self.log(
67+
"train_acc_step",
68+
self.train_acc,
69+
on_step=True,
70+
on_epoch=False,
71+
prog_bar=True,
72+
)
73+
return loss
74+
75+
def training_epoch_end(self, outputs):
76+
self.train_acc.reset()
77+
78+
def test_step(self, batch, batch_idx):
79+
x, y = batch
80+
y_hat = self._common_step(x, batch_idx)
81+
loss = F.cross_entropy(y_hat, y)
82+
accuracy = self.test_acc(y_hat, y)
83+
self.log("test_loss", loss, on_step=True)
84+
self.log("test_acc", accuracy, on_step=True)
85+
86+
def validation_step(self, batch, batch_idx):
87+
x, y = batch
88+
y_hat = self._common_step(x, batch_idx)
89+
loss = F.cross_entropy(y_hat, y)
90+
accuracy = self.test_acc(y_hat, y)
91+
self.log("val_loss", loss, on_step=True)
92+
self.log("val_acc", accuracy, on_step=True)
93+
94+
def predict_step(self, batch, batch_idx):
95+
x, y = batch
96+
y_hat = self._common_step(x)
97+
return y_hat
98+
99+
def _common_step(self, x, batch_idx):
100+
x = self.pool(F.relu(self.conv1(x)))
101+
x = self.pool(F.relu(self.conv2(x)))
102+
x = x.reshape(x.shape[0], -1)
103+
y_hat = self.fc1(x)
104+
return y_hat
105+
106+
def configure_optimizers(self):
107+
optimizer = optim.Adam(self.parameters(), lr=self.lr)
108+
return optimizer
109+
110+
111+
class MNISTDataModule(pl.LightningDataModule):
112+
def __init__(self, batch_size=512):
113+
super().__init__()
114+
self.batch_size = batch_size
115+
116+
def setup(self, stage):
117+
mnist_full = train_dataset = datasets.MNIST(
118+
root="dataset/", train=True, transform=transforms.ToTensor(), download=True
119+
)
120+
self.mnist_test = datasets.MNIST(
121+
root="dataset/", train=False, transform=transforms.ToTensor(), download=True
122+
)
123+
self.mnist_train, self.mnist_val = torch.utils.data.random_split(
124+
mnist_full, [55000, 5000]
125+
)
126+
127+
def train_dataloader(self):
128+
return DataLoader(
129+
self.mnist_train,
130+
batch_size=self.batch_size,
131+
num_workers=6,
132+
shuffle=True,
133+
)
134+
135+
def val_dataloader(self):
136+
return DataLoader(
137+
self.mnist_val, batch_size=self.batch_size, num_workers=2, shuffle=False
138+
)
139+
140+
def test_dataloader(self):
141+
return DataLoader(
142+
self.mnist_test, batch_size=self.batch_size, num_workers=2, shuffle=False
143+
)
144+
145+
146+
class MyPrintingCallback(Callback):
147+
def on_train_start(self, trainer, pl_module):
148+
print("Training is starting")
149+
150+
def on_train_end(self, trainer, pl_module):
151+
print("Training is ending")
152+
153+
154+
# Set device
155+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
156+
157+
# Load Data
158+
if __name__ == "__main__":
159+
# Initialize network
160+
model_lightning = CNNLightning()
161+
162+
trainer = pl.Trainer(
163+
#fast_dev_run=True,
164+
# overfit_batches=3,
165+
max_epochs=5,
166+
precision=16,
167+
accelerator="gpu",
168+
devices=[0,1],
169+
callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
170+
auto_lr_find=True,
171+
enable_model_summary=True,
172+
profiler="simple",
173+
strategy="deepspeed_stage_1",
174+
# accumulate_grad_batches=2,
175+
# auto_scale_batch_size="binsearch",
176+
# log_every_n_steps=1,
177+
)
178+
179+
dm = MNISTDataModule()
180+
181+
# trainer tune first to find best batch size and lr
182+
trainer.tune(model_lightning, dm)
183+
184+
trainer.fit(
185+
model=model_lightning,
186+
datamodule=dm,
187+
)
188+
189+
# test model on test loader from LightningDataModule
190+
trainer.test(model=model_lightning, datamodule=dm)

0 commit comments

Comments
 (0)