
Commit c832617

Merge pull request #1265 from dawidborycki/LP-PyTorch-Digit-Classification-Training
LP: Using PyTorch for training a digit classifier
2 parents 857f2ab + 2b67a64 commit c832617

File tree: 9 files changed, +455 −0 lines changed

Binary image files added (Figures): 573 KB, 360 KB, 395 KB
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
---
title: Learn how to train the PyTorch model for digit classification

minutes_to_complete: 40

who_is_this_for: This is an introductory topic for software developers interested in learning how to use PyTorch to train a feedforward neural network for digit classification.

learning_objectives:
    - Download and prepare the dataset.
    - Train a neural network using PyTorch.

prerequisites:
    - An x86_64 or Apple development machine with a code editor (we recommend Visual Studio Code).

author_primary: Dawid Borycki

### Tags
skilllevels: Introductory
subjects: Neural Networks
armips:
    - Cortex-A
    - Cortex-X
operatingsystems:
    - Windows
    - Linux
    - MacOS
tools_software_languages:
    - Android Studio
    - Coding

### FIXED, DO NOT MODIFY
# ================================================================================
weight: 1                       # _index.md always has weight of 1 to order correctly
layout: "learningpathall"       # All files under learning paths have this same wrapper
learning_path_main_page: "yes"  # This should be surfaced when looking for related content. Only set for _index.md of learning path content.
---
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
---
# ================================================================================
# Edit
# ================================================================================

next_step_guidance: >
    Proceed to Get Started with Arm Performance Studio for mobile to continue learning about Android performance analysis.

# 1-3 sentence recommendation outlining how the reader can generally keep learning about these topics, and a specific explanation of why the next step is being recommended.

recommended_path: "/learning-paths/smartphones-and-mobile/ams/"

# Link to the next learning path being recommended (for example, this could be /learning-paths/servers-and-cloud-computing/mongodb).


# further_reading links to references related to this path. Can be:
#     Manuals for a tool / software mentioned (type: documentation)
#     Blog about related topics (type: blog)
#     General online references (type: website)

further_reading:
    - resource:
        title: PyTorch
        link: https://pytorch.org
        type: documentation
    - resource:
        title: MNIST
        link: https://en.wikipedia.org/wiki/MNIST_database
        type: website
    - resource:
        title: Visual Studio Code
        link: https://code.visualstudio.com
        type: website


# ================================================================================
# FIXED, DO NOT MODIFY
# ================================================================================
weight: 21                  # set to always be larger than the content in this path, and one more than 'review'
title: "Next Steps"         # Always the same
layout: "learningpathall"   # All files under learning paths have this same wrapper
---
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
---
# ================================================================================
# Edit
# ================================================================================

# Always 3 questions. Should try to test the reader's knowledge, and reinforce the key points you want them to remember.
# question: A one sentence question
# answers: The correct answers (from 2-4 answer options only). Should be surrounded by quotes.
# correct_answer: An integer indicating what answer is correct (index starts from 0)
# explanation: A short (1-3 sentence) explanation of why the correct answer is correct. Can add additional context if desired


review:
    - questions:
        question: >
            What is TorchScript used for in the process described above?
        answers:
            - To optimize the model’s weights and biases during training.
            - To visualize the model’s predictions on new data.
            - To save the model’s architecture and parameters in a portable format.
            - To increase the learning rate during training.
        correct_answer: 3
        explanation: >
            TorchScript is used to serialize both the model’s architecture and its learned parameters, making the model portable and independent of the original class definition. This simplifies deployment and allows the model to be loaded and used in different environments without needing the original code.
    - questions:
        question: >
            Which loss function was used to train the PyTorch model on the MNIST dataset?
        answers:
            - Mean Squared Error Loss
            - CrossEntropyLoss
            - Hinge Loss
            - Binary Cross-Entropy Loss
        correct_answer: 2
        explanation: >
            The CrossEntropyLoss function was used to train the model because it is suitable for multi-class classification tasks like digit classification. It measures the difference between the predicted probabilities and the true class labels, helping the model learn to make accurate predictions.
    - questions:
        question: >
            Why do we set the model to evaluation mode during inference?
        answers:
            - To increase the learning rate.
            - To prevent changes to the model’s architecture.
            - To ensure layers like dropout and batch normalization behave correctly.
            - To load additional data for training.
        correct_answer: 3
        explanation: >
            Setting the model to evaluation mode (model.eval()) ensures that certain layers, such as dropout and batch normalization, function correctly during inference. In evaluation mode, dropout is disabled, and batch normalization uses running averages instead of batch statistics, providing consistent and accurate predictions.

# ================================================================================
# FIXED, DO NOT MODIFY
# ================================================================================
title: "Review"                 # Always the same title
weight: 20                      # Set to always be larger than the content in this path
layout: "learningpathall"       # All files under learning paths have this same wrapper
---
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
---
# User change
title: "Datasets and training"

weight: 3

layout: "learningpathall"
---

We start by downloading the MNIST dataset. Proceed as follows:
1. Open the pytorch-digits.ipynb notebook you created in this [Learning Path](learning-paths/cross-platform/pytorch-digit-classification-architecture).
2. Add the following statements:
```Python
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Training data
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

# Test data
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

# Dataloaders
batch_size = 32

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
```

The code snippet above downloads the MNIST dataset, transforms the images into tensors, and sets up data loaders for training and testing. Specifically, the datasets.MNIST function downloads the MNIST dataset, with train=True indicating training data and train=False indicating test data. The transform=transforms.ToTensor() argument converts each image in the dataset into a PyTorch tensor, which is necessary for model training and evaluation.

The DataLoader wraps the datasets and allows efficient loading of data in batches. It can handle data shuffling, batching, and parallel loading; here, train_dataloader and test_dataloader are created with a batch_size of 32, meaning they load 32 images per batch during training and testing.

This setup prepares the training and test datasets for use in a machine learning model, enabling efficient data handling and model training in PyTorch.
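
If you would like to verify the data loaders before moving on, you can inspect a single batch (a quick optional check, not part of the original Learning Path code):

```Python
# Fetch one batch from the training loader and confirm the tensor shapes.
images, labels = next(iter(train_dataloader))

print(images.shape)  # torch.Size([32, 1, 28, 28]) - 32 grayscale 28x28 images
print(labels.shape)  # torch.Size([32]) - one integer label (0-9) per image
```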

To run the MNIST download code above, you will need to install the certifi package:

```console
pip install certifi
```

certifi is a Python package that provides Mozilla's root certificates, which are essential for establishing secure SSL connections. If you are using macOS, you may also need to install the certificates by running:

```console
/Applications/Python\ 3.x/Install\ Certificates.command
```

Make sure to replace `x` with the Python version you have installed.

After running the code, you will see output similar to the following:

![image](Figures/01.png)

# Training
Now we have all the tools needed to train the model. We first specify the loss function and the optimizer:

```Python
# torch, nn, and the model object were defined in the notebook you created
# in the previous Learning Path.
learning_rate = 1e-3

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
```

We use CrossEntropyLoss as the loss function and the Adam optimizer for training. The learning rate is set to 1e-3.

Next, we define the methods for training and evaluating our feedforward neural network:

```Python
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (x, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for x, y in dataloader:
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
```

The first method, train_loop, uses the backpropagation algorithm to optimize the trainable parameters and minimize the prediction error of the neural network. The second method, test_loop, calculates the neural network error using the test images and displays the accuracy and loss values.

We can now invoke these methods to train and evaluate the model. As with TensorFlow, we use 10 epochs.

```Python
epochs = 10

for t in range(epochs):
    print(f"Epoch {t+1}:")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
```

After running this code, you will see output showing the training progress:

![image](Figures/02.png)

Once training is complete, you will see something like the following:

```output
Epoch 10:
Accuracy: 95.4%, Avg loss: 1.507491
```

which shows that the model achieved around 95% accuracy.

# Saving the model
Once the model is trained, we can save it. There are various approaches for this. In PyTorch, you can save both the model's structure and its weights to the same file using the torch.save() function. Alternatively, you can save only the weights (parameters) of the model, not the model architecture itself; this requires you to have the model's architecture defined separately when loading. To save the model weights, use the following command:

```Python
torch.save(model.state_dict(), "model_weights.pth")
```

However, PyTorch does not save the definition of the model class itself. When you load the model using torch.load(), PyTorch needs to know the class definition to recreate the model object.

Therefore, when you later want to use the saved model for inference, you will need to provide the definition of the model class.
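
Loading those weights later might look like this (a minimal sketch; NeuralNetwork is a hypothetical placeholder for the model class you defined in the previous Learning Path):

```Python
# Recreate the architecture first - the class definition must be available.
# NeuralNetwork is a placeholder name; use your own model class here.
model = NeuralNetwork()

# Populate the model with the saved parameters.
model.load_state_dict(torch.load("model_weights.pth"))
model.eval()  # switch to inference behavior before making predictions
```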

Alternatively, you can use TorchScript, which serializes both the architecture and the weights into a single file that can be loaded without needing the original class definition. This is particularly useful for deploying models to production or sharing models without code dependencies.

Here, we use TorchScript and save the model using the following commands:

```Python
# Set model to evaluation mode
model.eval()

# Trace the model with an example input
traced_model = torch.jit.trace(model, torch.rand(1, 1, 28, 28))

# Save the traced model
traced_model.save("model.pth")
```

The above commands set the model to evaluation mode (model.eval()), trace the model, and save it. Tracing is useful for converting models with static computation graphs to TorchScript, making them portable and independent of the original class definition.
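
Because the traced file is self-contained, it can be reloaded without the class definition. A minimal sketch (the random input is only a placeholder for a real image):

```Python
# Load the traced model - no Python class definition is required.
loaded_model = torch.jit.load("model.pth")

# Run a forward pass on a dummy input shaped like one MNIST image.
output = loaded_model(torch.rand(1, 1, 28, 28))
print(output.shape)  # expected: torch.Size([1, 10]) - one score per digit
```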

Setting the model to evaluation mode before tracing is important for several reasons:

1. Behavior of Layers like Dropout and BatchNorm (see the sketch after this list):
* Dropout. During training (model.train()), dropout randomly zeroes out some of the activations to prevent overfitting. During evaluation (model.eval()), dropout is turned off and all activations are used.
* BatchNorm. During training, batch normalization layers use batch statistics to normalize the input. During evaluation, they use the running averages calculated during training.

2. Consistent Inference Behavior. By setting the model to eval() mode, you ensure that the traced model behaves consistently during inference, as it will not use dropout or batch statistics that are inappropriate for inference.

3. Correct Tracing. Tracing captures the operations performed by the model using a given input. If the model is in training mode, the traced graph may include operations related to dropout and batch normalization updates. These operations can affect the correctness and performance of the model during inference.
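
The following standalone snippet (a minimal sketch, independent of the model trained above) demonstrates the dropout difference between the two modes:

```Python
import torch
from torch import nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()    # training mode: dropout is active
print(dropout(x))  # roughly half the values are zeroed; survivors are scaled by 2

dropout.eval()     # evaluation mode: dropout is disabled
print(dropout(x))  # identity: all values pass through unchanged
```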

In the next step, we will use the saved model for inference.
