Skip to content

Commit 0172a65

Browse files
Merge pull request #1483 from madeline-underwood/MNIST-digit-classification
Mnist digit classification LP_KB to review
2 parents 207e747 + 68847e5 commit 0172a65

File tree

16 files changed

+279
-198
lines changed

16 files changed

+279
-198
lines changed

content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/_index.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
---
2-
title: Create and train a PyTorch model for digit classification
2+
title: Create and train a PyTorch model for digit classification using the MNIST dataset
33

44
minutes_to_complete: 160
55

6-
who_is_this_for: This is an advanced topic for software developers interested in learning how to use PyTorch to create and train a feedforward neural network for digit classification. You will also learn how to use the trained model in an Android application. Finally, you will apply model optimizations.
6+
who_is_this_for: This is an advanced topic for software developers interested in learning how to use PyTorch to create and train a feedforward neural network for digit classification, and also software developers interested in learning how to use and apply optimizations to the trained model in an Android application.
77

88
learning_objectives:
99
- Prepare a PyTorch development environment.
1010
- Download and prepare the MNIST dataset.
11-
- Create a neural network architecture using PyTorch.
12-
- Train a neural network using PyTorch.
13-
- Create an Android app and loading the pre-trained model.
11+
- Create and train a neural network architecture using PyTorch.
12+
- Create an Android app and load the pre-trained model.
1413
- Prepare an input dataset.
1514
- Measure the inference time.
1615
- Optimize a neural network architecture using quantization and fusing.
17-
- Use an optimized model in the Android application.
16+
- Deploy an optimized model in an Android application.
1817

1918
prerequisites:
20-
- A computer that can run Python3, Visual Studio Code, and Android Studio. The OS can be Windows, Linux, or macOS.
19+
- A machine that can run Python3, Visual Studio Code, and Android Studio.
20+
- For the OS, you can use Windows, Linux, or macOS.
2121

2222

2323
author_primary: Dawid Borycki

content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/_next-steps.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# ================================================================================
55

66
next_step_guidance: >
7-
Proceed to Use Keras Core with TensorFlow, PyTorch, and JAX backends to continue exploring Machine Learning.
7+
To continue exploring Maching Learning, you can now learn about using Keras Core with TensorFlow, PyTorch, and JAX backends.
88
99
# 1-3 sentence recommendation outlining how the reader can generally keep learning about these topics, and a specific explanation of why the next step is being recommended.
1010

content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/_review.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,31 @@ review:
1515
question: >
1616
Does the input layer of the model flatten the 28x28 pixel image into a 1D array of 784 elements?
1717
answers:
18-
- "Yes"
19-
- "No"
18+
- "Yes."
19+
- "No."
2020
correct_answer: 1
2121
explanation: >
2222
Yes, the model uses nn.Flatten() to reshape the 28x28 pixel image into a 1D array of 784 elements for processing by the fully connected layers.
2323
- questions:
2424
question: >
25-
Will the model make random predictions if it’s run before training?
25+
Will the model make random predictions if it is run before training?
2626
answers:
27-
- "Yes"
28-
- "No"
27+
- "Yes."
28+
- "No."
2929
correct_answer: 1
3030
explanation: >
31-
Yes, however in such the case the model will produce random outputs, as the network has not been trained to recognize any patterns from the data.
31+
Yes, however in this scenario the model will produce random outputs, as the network has not been trained to recognize any patterns from the data.
3232
- questions:
3333
question: >
34-
Which loss function was used to train the PyTorch model on the MNIST dataset?
34+
Which loss function did you use to train the PyTorch model on the MNIST dataset in this Learning Path?
3535
answers:
36-
- Mean Squared Error Loss
37-
- Cross Entropy Loss
38-
- Hinge Loss
36+
- Mean Squared Error Loss.
37+
- Cross-Entropy Loss.
38+
- Hinge Loss.
3939
- Binary Cross-Entropy Loss
4040
correct_answer: 2
4141
explanation: >
42-
Cross Entropy Loss was used to train the model because it is suitable for multi-class classification tasks like digit classification. It measures the difference between the predicted probabilities and the true class labels, helping the model learn to make accurate predictions.
42+
Cross-Entropy Loss was used to train the model as it is suitable for multi-class classification such as digit classification. It measures the difference between the predicted probabilities and the true class labels, helping the model to learn to make accurate predictions.
4343
4444
# ================================================================================
4545
# FIXED, DO NOT MODIFY

content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/app.md

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,31 @@ weight: 10
77
layout: "learningpathall"
88
---
99

10-
You are now ready to run the Android application. You can use an emulator or a physical device.
11-
12-
The screenshots below show an emulator.
10+
You are now ready to run the Android application. The screenshots below show an emulator, but you can also use a physical device.
1311

1412
To run the app in Android Studio using an emulator, follow these steps:
1513

1614
1. Configure the Emulator:
17-
* Go to Tools > Device Manager (or click the Device Manager icon on the toolbar).
18-
* Click Create Device to set up a new virtual device (if you haven’t done so already).
19-
* Choose a device model, such as Pixel 4, and click Next.
20-
* Select a system image, such as Android 11, API level 30, and click Next.
21-
* Review the settings and click Finish to create the emulator.
15+
16+
* Go to **Tools** > **Device Manager**, or click the Device Manager icon on the toolbar.
17+
* Click **Create Device** to set up a new virtual device, if you haven’t done so already.
18+
* Choose a device model, such as the Pixel 4, and click **Next**.
19+
* Select a system image, such as Android 11, API level 30, and click **Next**.
20+
* Review the settings, and click **Finish** to create the emulator.
2221

2322
2. Run the App:
24-
* Make sure the emulator is selected in the device dropdown menu in the toolbar (next to the “Run” button).
25-
* Click the Run button (a green triangle). Android Studio will build the app, install it on the emulator, and launch it.
2623

27-
3. View the App on the Emulator: Once the app is installed, it will automatically open on the emulator screen, allowing you to interact with it as if it were running on a real device.
24+
* Make sure the emulator is selected in the device drop-down menu in the toolbar, next to the **Run** button.
25+
* Click the **Run** button, which is a green triangle. Android Studio builds the app, installs it on the emulator, and then launches it.
26+
27+
3. View the App on the Emulator:
28+
29+
* Once the app is installed, it automatically opens on the emulator screen, allowing you to interact with it as if it were running on a real device.
2830

29-
Once the application is started, click the Load Image button. It will load a randomly selected image. Then, click Run Inference to recognize the digit. The application will display the predicted label and the inference time as shown below:
31+
Once the application starts, click the **Load Image** button. It loads a randomly-selected image. Then, click **Run Inference** to recognize the digit. The application displays the predicted label and the inference time as shown below:
3032

31-
![img](Figures/05.png)
33+
![img alt-text#center](Figures/05.png "Figure 7. Digit Recognition 1")
3234

33-
![img](Figures/06.png)
35+
![img alt-text#center](Figures/06.png "Figure 8. Digit Recognition 2")
3436

35-
In the next step you will learn how to further optimize the model.
37+
In the next step of this Learning Path, you will learn how to further optimize the model.

content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/datasets-and-training.md

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
---
22
# User change
3-
title: "Perform training and save the model"
3+
title: "Perform Training and Save the Model"
44

55
weight: 5
66

@@ -9,9 +9,9 @@ layout: "learningpathall"
99

1010
## Prepare the MNIST data
1111

12-
Start by downloading the MNIST dataset. Proceed as follows:
12+
Start by downloading the MNIST dataset.
1313

14-
1. Open the pytorch-digits.ipynb you created earlier.
14+
1. Open the `pytorch-digits.ipynb` you created earlier.
1515

1616
2. Add the following statements:
1717

@@ -42,9 +42,15 @@ train_dataloader = DataLoader(training_data, batch_size=batch_size)
4242
test_dataloader = DataLoader(test_data, batch_size=batch_size)
4343
```
4444

45-
The above code snippet downloads the MNIST dataset, transforms the images into tensors, and sets up data loaders for training and testing. Specifically, the `datasets.MNIST` function is used to download the MNIST dataset, with `train=True` indicating training data and `train=False` indicating test data. The `transform=transforms.ToTensor()` argument converts each image in the dataset into a PyTorch tensor, which is necessary for model training and evaluation.
45+
Using this code enables you to:
4646

47-
The DataLoader wraps the datasets and allows efficient loading of data in batches. It handles data shuffling, batching, and parallel loading. Here, the train_dataloader and test_dataloader are created with a batch_size of 32, meaning they will load 32 images per batch during training and testing.
47+
* Download the MNIST dataset.
48+
* Transform the images into tensors.
49+
* Set up data loaders for training and testing.
50+
51+
Specifically, the `datasets.MNIST` function downloads the MNIST dataset, with `train=True` indicating training data and `train=False` indicating test data. The `transform=transforms.ToTensor()` argument converts each image in the dataset into a PyTorch tensor, which is necessary for model training and evaluation.
52+
53+
The DataLoader wraps the datasets and enables efficient loading of data in batches. It handles data shuffling, batching, and parallel loading. Here, the train_dataloader and test_dataloader are created with a batch_size of 32, meaning they will load 32 images per batch during training and testing.
4854

4955
This setup prepares the training and test datasets for use in a machine learning model, enabling efficient data handling and model training in PyTorch.
5056

@@ -54,19 +60,21 @@ To run the above code, you will need to install certifi package:
5460
pip install certifi
5561
```
5662

57-
The certifi Python package provides the Mozilla root certificates, which are essential for ensuring the SSL connections are secure. If you’re using macOS, you may also need to install the certificates by running:
63+
The certifi Python package provides the Mozilla root certificates, which are essential for ensuring the SSL connections are secure. If you’re using macOS, you might also need to install the certificates by running:
5864

5965
```console
6066
/Applications/Python\ 3.x/Install\ Certificates.command
6167
```
6268

63-
Make sure to replace `x` with the number of Python version you have installed.
69+
{{% notice Note %}}
70+
Make sure to replace 'x' with the version number of Python that you have installed.
71+
{{% /notice %}}
6472

65-
After running the code you see output similar to the screenshot below:
73+
After running the code, you will see output similar to Figure 5:
6674

67-
![image](Figures/01.png)
75+
![image alt-text#center](Figures/01.png "Figure 5. Output".)
6876

69-
# Train the model
77+
## Train the Model
7078

7179
To train the model, specify the loss function and the optimizer:
7280

@@ -77,7 +85,7 @@ loss_fn = nn.CrossEntropyLoss()
7785
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
7886
```
7987

80-
Use CrossEntropyLoss as the loss function and the Adam optimizer for training. The learning rate is set to 1e-3.
88+
Use `CrossEntropyLoss` as the loss function and the Adam optimizer for training. The learning rate is set to 1e-3.
8189

8290
Next, define the methods for training and evaluating the feedforward neural network:
8391

@@ -111,7 +119,7 @@ def test_loop(dataloader, model, loss_fn):
111119
print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
112120
```
113121

114-
The first method, `train_loop`, uses the backpropagation algorithm to optimize the trainable parameters and minimize the prediction error of the neural network. The second method, `test_loop`, calculates the neural network error using the test images and displays the accuracy and loss values.
122+
The first method, `train_loop`, uses the backpropagation algorithm to optimize the trainable parameters and minimize the prediction error rate of the neural network. The second method, `test_loop`, calculates the neural network error rate using the test images, and displays the accuracy and loss values.
115123

116124
You can now invoke these methods to train and evaluate the model using 10 epochs.
117125

@@ -124,9 +132,9 @@ for t in range(epochs):
124132
test_loop(test_dataloader, model, loss_fn)
125133
```
126134

127-
After running the code, you see the following output showing the training progress.
135+
After running the code, you see the following output showing the training progress, as displayed in Figure 2.
128136

129-
![image](Figures/02.png)
137+
![image alt-text#center](Figures/02.png "Figure 2. Output 2")
130138

131139
Once the training is complete, you see output similar to:
132140

@@ -139,13 +147,13 @@ The output shows the model achieved around 95% accuracy.
139147

140148
# Save the model
141149

142-
Once the model is trained, you can save it. There are various approaches for this. In PyTorch, you can save both the model’s structure and its weights to the same file using the `torch.save()` function. Alternatively, you can save only the weights (parameters) of the model, not the model architecture itself. This requires you to have the model’s architecture defined separately when loading. To save the model weights, you can use the following command:
150+
Once the model is trained, you can save it. There are various approaches for this. In PyTorch, you can save both the model’s structure and its weights to the same file using the `torch.save()` function. Alternatively, you can save only the weights of the model, not the model architecture itself. This requires you to have the model’s architecture defined separately when loading. To save the model weights, you can use the following command:
143151

144152
```Python
145153
torch.save(model.state_dict(), "model_weights.pth").
146154
```
147155

148-
However, PyTorch does not save the definition of the class itself. When you load the model using `torch.load()`, PyTorch needs to know the class definition to recreate the model object.
156+
However, PyTorch does not save the definition of the class itself. When you load the model using `torch.load()`, PyTorch requires the class definition to recreate the model object.
149157

150158
Therefore, when you later want to use the saved model for inference, you will need to provide the definition of the model class.
151159

@@ -164,16 +172,22 @@ traced_model = torch.jit.trace(model, torch.rand(1, 1, 28, 28))
164172
traced_model.save("model.pth")
165173
```
166174

167-
The above commands set the model to evaluation mode, trace the model, and save it. Tracing is useful for converting models with static computation graphs to TorchScript, making them portable and independent of the original class definition.
175+
The above commands perform the following tasks:
176+
177+
* They set the model to evaluation mode.
178+
* They trace the model.
179+
* They save it.
180+
181+
Tracing is useful for converting models with static computation graphs to TorchScript, making them flexible and independent of the original class definition.
168182

169183
Setting the model to evaluation mode before tracing is important for several reasons:
170184

171-
1. Behavior of Layers like Dropout and BatchNorm:
172-
* Dropout. During training, dropout randomly zeroes out some of the activations to prevent overfitting. During evaluation dropout is turned off, and all activations are used.
185+
1. Behavior of Layers like Dropout and BatchNorm:
186+
* Dropout. During training, dropout randomly zeroes out some of the activations to prevent overfitting. During evaluation, dropout is turned off, and all activations are used.
173187
* BatchNorm. During training, Batch Normalization layers use batch statistics to normalize the input. During evaluation, they use running averages calculated during training.
174188

175-
2. Consistent Inference Behavior. By setting the model to eval mode, you ensure that the traced model will behave consistently during inference, as it will not use dropout or batch statistics that are inappropriate for inference.
189+
2. Consistent Inference Behavior. By setting the model to eval mode, you ensure that the traced model behaves consistently during inference, as it does not use dropout or batch statistics that are inappropriate for inference.
176190

177-
3. Correct Tracing. Tracing captures the operations performed by the model using a given input. If the model is in training mode, the traced graph may include operations related to dropout and batch normalization updates. These operations can affect the correctness and performance of the model during inference.
191+
3. Correct Tracing. Tracing captures the operations performed by the model using a given input. If the model is in training mode, the traced graph might include operations related to dropout and batch normalization updates. These operations can affect the correctness and performance of the model during inference.
178192

179193
In the next step, you will use the saved model for ML inference.

0 commit comments

Comments
 (0)