 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+# Source code inspired from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
 import torchvision
-from timm import create_model
+import torchvision.transforms as transforms
+from torch import nn
 from torch.utils.data import DataLoader
 from torchmetrics.classification import MulticlassAccuracy
-from torchvision import transforms as T

 from gradsflow import AutoDataset, Model
-from gradsflow.callbacks import (
-    CometCallback,
-    CSVLogger,
-    EmissionTrackerCallback,
-    ModelCheckpoint,
-    WandbCallback,
-)
-from gradsflow.data.common import random_split_dataset
-
-# Replace dataloaders with your custom dataset and you are all set to train your model
+from gradsflow.callbacks import CSVLogger, ModelCheckpoint
+
+# Replace dataloaders with your custom dataset, and you are all set to train your model
 image_size = (64, 64)
 batch_size = 4

-to_rgb = lambda x: x.convert("RGB")
+transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+trainset = torchvision.datasets.CIFAR10(root="~/data", train=True, download=True, transform=transform)
+train_dl = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

-augs = T.Compose([to_rgb, T.AutoAugment(), T.Resize(image_size), T.ToTensor()])
-data = torchvision.datasets.CIFAR10("~/data", download=True, transform=augs)
-train_data, val_data = random_split_dataset(data, 0.99)
-train_dl = DataLoader(train_data, batch_size=batch_size)
-val_dl = DataLoader(val_data, batch_size=batch_size)
-num_classes = len(data.classes)
+testset = torchvision.datasets.CIFAR10(root="~/data", train=False, download=True, transform=transform)
+val_dl = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
+num_classes = len(trainset.classes)
 cbs = [
     CSVLogger(
         verbose=True,
     ),
     ModelCheckpoint(),
-    EmissionTrackerCallback(),
+    # EmissionTrackerCallback(),
     # CometCallback(offline=True),
-    WandbCallback(),
+    # WandbCallback(),
 ]

+
+def imshow(img):
+    img = img / 2 + 0.5  # unnormalize
+    npimg = img.numpy()
+    plt.imshow(np.transpose(npimg, (1, 2, 0)))
+    plt.show()
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1)  # flatten all dimensions except batch
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
 if __name__ == "__main__":
     autodataset = AutoDataset(train_dl, val_dl, num_classes=num_classes)
-    cnn = create_model("resnet18", pretrained=False, num_classes=num_classes)
+    net = Net()
+    model = Model(net)
+    criterion = nn.CrossEntropyLoss()
+
+    model.compile(
+        criterion,
+        optim.SGD,
+        optimizer_config={"momentum": 0.9},
+        learning_rate=0.001,
+        metrics=[MulticlassAccuracy(autodataset.num_classes)],
+    )
+    model.fit(autodataset, max_epochs=2, callbacks=cbs)
+
+    dataiter = iter(val_dl)
+    images, labels = next(dataiter)
+
+    # print images
+    # imshow(torchvision.utils.make_grid(images))
+    print("GroundTruth: ", " ".join(f"{trainset.classes[labels[j]]:5s}" for j in range(4)))

-    model = Model(cnn)
+    outputs = net(images)
+    _, predicted = torch.max(outputs, 1)

-    model.compile("crossentropyloss", "adam", metrics=[MulticlassAccuracy(autodataset.num_classes)])
-    model.fit(autodataset, max_epochs=10, steps_per_epoch=10, callbacks=cbs)
+    print("Predicted: ", " ".join(f"{trainset.classes[predicted[j]]:5s}" for j in range(4)))
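A note on the `16 * 5 * 5` input size of `fc1` in the new Net class: CIFAR-10 images are 3x32x32, `conv1` (5x5 kernel, no padding) produces 6x28x28, the 2x2 max-pool halves that to 6x14x14, `conv2` produces 16x10x10, and the second pool leaves 16x5x5 = 400 features. A minimal standalone sketch (not part of this change) walking the same shapes; the ReLUs are omitted because they do not change shape:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)                   # one fake CIFAR-10 image
x = nn.MaxPool2d(2, 2)(nn.Conv2d(3, 6, 5)(x))   # -> [1, 6, 14, 14]
x = nn.MaxPool2d(2, 2)(nn.Conv2d(6, 16, 5)(x))  # -> [1, 16, 5, 5]
print(x.flatten(1).shape)                       # -> [1, 400] == 16 * 5 * 5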
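The script only inspects the first batch of four test images. For a full test-set accuracy figure after `model.fit`, a rough sketch in the spirit of the linked PyTorch tutorial (assuming `net` and `val_dl` from the script above; a hypothetical addition, not part of this change):

net.eval()
correct = 0
total = 0
with torch.no_grad():                         # no gradients needed for evaluation
    for images, labels in val_dl:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)  # index of the highest logit per image
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Accuracy on the test images: {100 * correct / total:.1f}%")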