
Commit fce6ea3: Review PyTorch MNIST Learning Paths

1 parent c832617

7 files changed: +95, -69 lines

content/learning-paths/cross-platform/pytorch-digit-classification-architecture/_index.md

Lines changed: 0 additions & 3 deletions
@@ -1,8 +1,5 @@
 ---
 title: Create a PyTorch model for digit classification
-draft: true
-cascade:
-  draft: true
 
 minutes_to_complete: 40
 

content/learning-paths/cross-platform/pytorch-digit-classification-architecture/_next-steps.md

Lines changed: 2 additions & 2 deletions
@@ -4,11 +4,11 @@
 # ================================================================================
 
 next_step_guidance: >
-    Proceed to Get Started with Arm Performance Studio for mobile to continue learning about Android performance analysis.
+    Continue to learn how to train the model with PyTorch and use it for inference.
 
 # 1-3 sentence recommendation outlining how the reader can generally keep learning about these topics, and a specific explanation of why the next step is being recommended.
 
-recommended_path: "/learning-paths/smartphones-and-mobile/ams/"
+recommended_path: "/learning-paths/cross-platform/pytorch-digit-classification-training/"
 
 # Link to the next learning path being recommended (for example, /learning-paths/servers-and-cloud-computing/mongodb).
 

content/learning-paths/cross-platform/pytorch-digit-classification-training/_index.md

Lines changed: 14 additions & 6 deletions
@@ -1,32 +1,40 @@
 ---
-title: Learn how to train the PyTorch model for digit classification
+title: Train a PyTorch model for digit classification
 minutes_to_complete: 40
 
 who_is_this_for: This is an introductory topic for software developers interested in learning how to use PyTorch to train a feedforward neural network for digit classification.
 
 learning_objectives:
-- Download and prepare the dataset.
+- Download and prepare the MNIST dataset.
 - Train a neural network using PyTorch.
 
 prerequisites:
-- A x86_64 or Apple development machine with Code Editor (we recommend Visual Studio Code).
+- Any computer that can run Python3 and Visual Studio Code; this can be Windows, Linux, or macOS.
+- You should complete [Create a PyTorch model for digit classification](/learning-paths/cross-platform/pytorch-digit-classification-architecture/) before starting this Learning Path.
 
 author_primary: Dawid Borycki
 
 ### Tags
 skilllevels: Introductory
-subjects: Neural Networks
+subjects: ML
 armips:
 - Cortex-A
 - Cortex-X
+- Neoverse
 operatingsystems:
 - Windows
 - Linux
-- MacOS
+- macOS
 tools_software_languages:
-- Android Studio
 - Coding
 
+shared_path: true
+shared_between:
+- servers-and-cloud-computing
+- laptops-and-desktops
+- smartphones-and-mobile
+
+
 ### FIXED, DO NOT MODIFY
 # ================================================================================
 weight: 1 # _index.md always has weight of 1 to order correctly

content/learning-paths/cross-platform/pytorch-digit-classification-training/_next-steps.md

Lines changed: 2 additions & 2 deletions
@@ -4,11 +4,11 @@
 # ================================================================================
 
 next_step_guidance: >
-    Proceed to Get Started with Arm Performance Studio for mobile to continue learning about Android performance analysis.
+    Proceed to Use Keras Core with TensorFlow, PyTorch, and JAX backends to continue exploring Machine Learning.
 
 # 1-3 sentence recommendation outlining how the reader can generally keep learning about these topics, and a specific explanation of why the next step is being recommended.
 
-recommended_path: "/learning-paths/smartphones-and-mobile/ams/"
+recommended_path: "/learning-paths/servers-and-cloud-computing/keras-core/"
 
 # Link to the next learning path being recommended (for example, /learning-paths/servers-and-cloud-computing/mongodb).
 

content/learning-paths/cross-platform/pytorch-digit-classification-training/datasets-and-training.md

Lines changed: 27 additions & 21 deletions
@@ -7,11 +7,13 @@ weight: 3
 layout: "learningpathall"
 ---
 
-We start by downloading the MNIST dataset. Proceed as follows:
-1. Open, the pytorch-digits.ipynb you created in this [Learning Path](learning-paths/cross-platform/pytorch-digit-classification-architecture).
+Start by downloading the MNIST dataset. Proceed as follows:
+
+1. Open the pytorch-digits.ipynb file you created in [Create a PyTorch model for digit classification](/learning-paths/cross-platform/pytorch-digit-classification-architecture/).
+
 2. Add the following statements:
 
-```Python
+```python
 from torchvision import transforms, datasets
 from torch.utils.data import DataLoader
 
@@ -38,31 +40,33 @@ train_dataloader = DataLoader(training_data, batch_size=batch_size)
 test_dataloader = DataLoader(test_data, batch_size=batch_size)
 ```
 
-The above code snippet downloads the MNIST dataset, transforms the images into tensors, and sets up data loaders for training and testing. Specifically, the datasets.MNIST function is used to download the MNIST dataset, with train=True indicating training data and train=False indicating test data. The transform=transforms.ToTensor() argument converts each image in the dataset into a PyTorch tensor, which is necessary for model training and evaluation.
+The above code snippet downloads the MNIST dataset, transforms the images into tensors, and sets up data loaders for training and testing. Specifically, the `datasets.MNIST` function is used to download the MNIST dataset, with `train=True` indicating training data and `train=False` indicating test data. The `transform=transforms.ToTensor()` argument converts each image in the dataset into a PyTorch tensor, which is necessary for model training and evaluation.
 
 The `DataLoader` wraps the datasets and allows efficient loading of data in batches. It handles data shuffling, batching, and parallel loading. Here, the `train_dataloader` and `test_dataloader` are created with a `batch_size` of 32, meaning they will load 32 images per batch during training and testing.
 
 This setup prepares the training and test datasets for use in a machine learning model, enabling efficient data handling and model training in PyTorch.
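
As a quick sanity check, you can pull a single batch from the loader and inspect its shape. This is a minimal sketch, not part of the commit, assuming the `train_dataloader` defined above:

```python
# Fetch one batch from the training loader and inspect it.
# MNIST images arrive as tensors of shape [batch, channels, height, width].
images, labels = next(iter(train_dataloader))
print(images.shape)  # torch.Size([32, 1, 28, 28])
print(labels.shape)  # torch.Size([32]) -- one digit label per image
```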

 To run the above code, you will need to install the certifi package:
+
 ```console
 pip install certifi
 ```
 
-certifi is a Python package that provides the Mozilla’s root certificates, which are essential for ensuring the SSL connections are secure. If you’re using macOS, you may also need to install the certificates by running:
+The certifi Python package provides the Mozilla root certificates, which are essential for ensuring that SSL connections are secure. If you’re using macOS, you may also need to install the certificates by running:
 
 ```console
 /Applications/Python\ 3.x/Install\ Certificates.command
 ```
 
-Make sure to replace `x` by the number of Python version you have installed.
+Make sure to replace `x` with the version number of the Python you have installed.
 
 After running the code, you will see output similar to the following:
 
 ![image](Figures/01.png)
 
-# Training
-Now, we have all the tools needed to train the model. We first specify the loss function and the optimizer:
+# Train the model
+
+To train the model, specify the loss function and the optimizer:
 
 ```Python
 learning_rate = 1e-3
@@ -71,9 +75,9 @@ loss_fn = nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 ```
 
-We use CrossEntropyLoss as the loss function and the Adam optimizer for training. The learning rate is set to 1e-3.
+Use `CrossEntropyLoss` as the loss function and the `Adam` optimizer for training. The learning rate is set to `1e-3`.
 
-Next, we define the methods for training and evaluating our feedforward neural network:
+Next, define the methods for training and evaluating the feedforward neural network:
 
 ```Python
 def train_loop(dataloader, model, loss_fn, optimizer):
@@ -105,9 +109,9 @@ def test_loop(dataloader, model, loss_fn):
     print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
 ```
 
-The first method, train_loop, uses the backpropagation algorithm to optimize the trainable parameters and minimize the prediction error of the neural network. The second method, test_loop, calculates the neural network error using the test images and displays the accuracy and loss values.
+The first method, `train_loop`, uses the backpropagation algorithm to optimize the trainable parameters and minimize the prediction error of the neural network. The second method, `test_loop`, calculates the neural network error using the test images and displays the accuracy and loss values.
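
The hunks above elide the bodies of both loops. For reference, here is a minimal sketch consistent with the visible signatures and the final print statement, following the standard PyTorch pattern; the file's exact bodies may differ, and `torch`, the model, `loss_fn`, and the optimizer from the notebook are assumed to be in scope:

```python
import torch

def train_loop(dataloader, model, loss_fn, optimizer):
    model.train()
    for X, y in dataloader:
        pred = model(X)            # forward pass
        loss = loss_fn(pred, y)    # compute prediction error
        optimizer.zero_grad()      # reset accumulated gradients
        loss.backward()            # backpropagation
        optimizer.step()           # update trainable parameters

def test_loop(dataloader, model, loss_fn):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():          # no gradients needed for evaluation
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
```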

-We can now invoke these methods to train and evaluate the model. Similarly to TensorFlow, we use 10 epochs.
+You can now invoke these methods to train and evaluate the model using 10 epochs.
 
 ```Python
 epochs = 10
@@ -131,20 +135,21 @@ Accuracy: 95.4%, Avg loss: 1.507491
 
 which shows that the model achieved around 95% accuracy.
 
-# Saving the model
-Once the model is trained, we can save it. There are various approaches for this. In PyTorch, you can save both the model’s structure and its weights to the same file using the torch.save() function. Alternatively, you can save only the weights (parameters) of the model, not the model architecture itself. This requires you to have the model’s architecture defined separately when loading. To save the model weights, you can use the following command:
+# Save the model
+
+Once the model is trained, you can save it. There are various approaches for this. In PyTorch, you can save both the model’s structure and its weights to the same file using the `torch.save()` function. Alternatively, you can save only the weights (parameters) of the model, not the model architecture itself. This requires you to have the model’s architecture defined separately when loading. To save the model weights, you can use the following command:
 
 ```Python
 torch.save(model.state_dict(), "model_weights.pth")
 ```
 
-However, PyTorch does not save the definition of the class itself. When you load the model using torch.load(), PyTorch needs to know the class definition to recreate the model object.
+However, PyTorch does not save the definition of the class itself. When you load the model using `torch.load()`, PyTorch needs to know the class definition to recreate the model object.
 
 Therefore, when you later want to use the saved model for inference, you will need to provide the definition of the model class.
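
For illustration only (not part of the commit), loading a weights-only file therefore looks something like the sketch below, where `NeuralNetwork` is a hypothetical stand-in for whatever class name the previous Learning Path defines:

```python
import torch

# Hypothetical: 'NeuralNetwork' stands in for the model class defined in
# the previous Learning Path; without that definition, the weights-only
# file cannot be turned back into a model object.
model = NeuralNetwork()
model.load_state_dict(torch.load("model_weights.pth"))
model.eval()  # switch to inference behavior
```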

 Alternatively, you can use TorchScript, which serializes both the architecture and weights into a single file that can be loaded without needing the original class definition. This is particularly useful for deploying models to production or sharing models without code dependencies.

-Here, we use TorchScript and save the model using the following commands:
+Use TorchScript to save the model using the following commands:
 
 ```Python
 # Set model to evaluation mode
@@ -157,15 +162,16 @@ traced_model = torch.jit.trace(model, torch.rand(1, 1, 28, 28))
 traced_model.save("model.pth")
 ```
 
-The above commands set the model to evaluation mode (model.eval()), then trace the model, and save it. Tracing is useful for converting models with static computation graphs to TorchScript, making them portable and independent of the original class definition.
+The above commands set the model to evaluation mode, trace the model, and save it. Tracing is useful for converting models with static computation graphs to TorchScript, making them portable and independent of the original class definition.
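
Because the traced file is self-contained, it can later be restored without the class definition. A minimal sketch, not part of the commit, assuming the `model.pth` file saved above:

```python
import torch

# Load the traced TorchScript model; no Python class definition is required.
model = torch.jit.load("model.pth")
model.eval()

# Run inference on a dummy 28x28 grayscale input.
dummy = torch.rand(1, 1, 28, 28)
print(model(dummy).argmax(1))  # predicted digit for the random input
```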

 Setting the model to evaluation mode before tracing is important for several reasons:
+
 1. Behavior of Layers like Dropout and BatchNorm:
-* Dropout. During training (model.train()), dropout randomly zeroes out some of the activations to prevent overfitting. During evaluation (model.eval()), dropout is turned off, and all activations are used.
-* BatchNorm. During training, Batch Normalization layers use batch statistics to normalize the input. During evaluation, they use running averages calculated during training.
+* Dropout. During training, dropout randomly zeroes out some of the activations to prevent overfitting. During evaluation, dropout is turned off, and all activations are used (see the sketch after this list).
+* BatchNorm. During training, Batch Normalization layers use batch statistics to normalize the input. During evaluation, they use running averages calculated during training.
 
-2. Consistent Inference Behavior. By setting the model to eval() mode, you ensure that the traced model will behave consistently during inference, as it will not use dropout or batch statistics that are inappropriate for inference.
+2. Consistent Inference Behavior. By setting the model to eval mode, you ensure that the traced model will behave consistently during inference, as it will not use dropout or batch statistics that are inappropriate for inference.
 
 3. Correct Tracing. Tracing captures the operations performed by the model using a given input. If the model is in training mode, the traced graph may include operations related to dropout and batch normalization updates. These operations can affect the correctness and performance of the model during inference.
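
To make the train/eval distinction concrete, here is a small standalone sketch (not from the commit) showing how a dropout layer changes behavior between the two modes:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 6)

drop.train()    # training mode: roughly half the activations are zeroed,
print(drop(x))  # and the survivors are scaled by 1/(1-p) = 2.0

drop.eval()     # evaluation mode: dropout is a no-op
print(drop(x))  # prints the input unchanged
```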

-In the next step, we will use the saved model for inference.
+In the next step, you will use the saved model for inference.
