You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/_index.md
+7-7Lines changed: 7 additions & 7 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -1,23 +1,23 @@
1
1
---
2
-
title: Create and train a PyTorch model for digit classification
2
+
title: Create and train a PyTorch model for digit classification using the MNIST dataset
3
3
4
4
minutes_to_complete: 160
5
5
6
-
who_is_this_for: This is an advanced topic for software developers interested in learning how to use PyTorch to create and train a feedforward neural network for digit classification. You will also learn how to use the trained model in an Android application. Finally, you will apply model optimizations.
6
+
who_is_this_for: This is an advanced topic for software developers interested in learning how to use PyTorch to create and train a feedforward neural network for digit classification, and also software developers interested in learning how to use and apply optimizations to the trained model in an Android application.
7
7
8
8
learning_objectives:
9
9
- Prepare a PyTorch development environment.
10
10
- Download and prepare the MNIST dataset.
11
-
- Create a neural network architecture using PyTorch.
12
-
- Train a neural network using PyTorch.
13
-
- Create an Android app and loading the pre-trained model.
11
+
- Create and train a neural network architecture using PyTorch.
12
+
- Create an Android app and load the pre-trained model.
14
13
- Prepare an input dataset.
15
14
- Measure the inference time.
16
15
- Optimize a neural network architecture using quantization and fusing.
17
-
- Use an optimized model in the Android application.
16
+
- Deploy an optimized model in an Android application.
18
17
19
18
prerequisites:
20
-
- A computer that can run Python3, Visual Studio Code, and Android Studio. The OS can be Windows, Linux, or macOS.
19
+
- A machine that can run Python3, Visual Studio Code, and Android Studio.
20
+
- For the OS, you can use Windows, Linux, or macOS.
Proceed to Use Keras Core with TensorFlow, PyTorch, and JAX backends to continue exploring Machine Learning.
7
+
To continue exploring Maching Learning, you can now learn about using Keras Core with TensorFlow, PyTorch, and JAX backends.
8
8
9
9
# 1-3 sentence recommendation outlining how the reader can generally keep learning about these topics, and a specific explanation of why the next step is being recommended.
Copy file name to clipboardExpand all lines: content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/_review.md
+11-11Lines changed: 11 additions & 11 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -15,31 +15,31 @@ review:
15
15
question: >
16
16
Does the input layer of the model flatten the 28x28 pixel image into a 1D array of 784 elements?
17
17
answers:
18
-
- "Yes"
19
-
- "No"
18
+
- "Yes."
19
+
- "No."
20
20
correct_answer: 1
21
21
explanation: >
22
22
Yes, the model uses nn.Flatten() to reshape the 28x28 pixel image into a 1D array of 784 elements for processing by the fully connected layers.
23
23
- questions:
24
24
question: >
25
-
Will the model make random predictions if it’s run before training?
25
+
Will the model make random predictions if it is run before training?
26
26
answers:
27
-
- "Yes"
28
-
- "No"
27
+
- "Yes."
28
+
- "No."
29
29
correct_answer: 1
30
30
explanation: >
31
-
Yes, however in such the case the model will produce random outputs, as the network has not been trained to recognize any patterns from the data.
31
+
Yes, however in this scenario the model will produce random outputs, as the network has not been trained to recognize any patterns from the data.
32
32
- questions:
33
33
question: >
34
-
Which loss function was used to train the PyTorch model on the MNIST dataset?
34
+
Which loss function did you use to train the PyTorch model on the MNIST dataset in this Learning Path?
35
35
answers:
36
-
- Mean Squared Error Loss
37
-
- CrossEntropy Loss
38
-
- Hinge Loss
36
+
- Mean Squared Error Loss.
37
+
- Cross-Entropy Loss.
38
+
- Hinge Loss.
39
39
- Binary Cross-Entropy Loss
40
40
correct_answer: 2
41
41
explanation: >
42
-
CrossEntropy Loss was used to train the model because it is suitable for multi-class classification tasks like digit classification. It measures the difference between the predicted probabilities and the true class labels, helping the model learn to make accurate predictions.
42
+
Cross-Entropy Loss was used to train the model as it is suitable for multi-class classification such as digit classification. It measures the difference between the predicted probabilities and the true class labels, helping the model to learn to make accurate predictions.
Copy file name to clipboardExpand all lines: content/learning-paths/cross-platform/pytorch-digit-classification-arch-training/app.md
+17-15Lines changed: 17 additions & 15 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -7,29 +7,31 @@ weight: 10
7
7
layout: "learningpathall"
8
8
---
9
9
10
-
You are now ready to run the Android application. You can use an emulator or a physical device.
11
-
12
-
The screenshots below show an emulator.
10
+
You are now ready to run the Android application. The screenshots below show an emulator, but you can also use a physical device.
13
11
14
12
To run the app in Android Studio using an emulator, follow these steps:
15
13
16
14
1. Configure the Emulator:
17
-
* Go to Tools > Device Manager (or click the Device Manager icon on the toolbar).
18
-
* Click Create Device to set up a new virtual device (if you haven’t done so already).
19
-
* Choose a device model, such as Pixel 4, and click Next.
20
-
* Select a system image, such as Android 11, API level 30, and click Next.
21
-
* Review the settings and click Finish to create the emulator.
15
+
16
+
* Go to **Tools** > **Device Manager**, or click the Device Manager icon on the toolbar.
17
+
* Click **Create Device** to set up a new virtual device, if you haven’t done so already.
18
+
* Choose a device model, such as the Pixel 4, and click **Next**.
19
+
* Select a system image, such as Android 11, API level 30, and click **Next**.
20
+
* Review the settings, and click **Finish** to create the emulator.
22
21
23
22
2. Run the App:
24
-
* Make sure the emulator is selected in the device dropdown menu in the toolbar (next to the “Run” button).
25
-
* Click the Run button (a green triangle). Android Studio will build the app, install it on the emulator, and launch it.
26
23
27
-
3. View the App on the Emulator: Once the app is installed, it will automatically open on the emulator screen, allowing you to interact with it as if it were running on a real device.
24
+
* Make sure the emulator is selected in the device drop-down menu in the toolbar, next to the **Run** button.
25
+
* Click the **Run** button, which is a green triangle. Android Studio builds the app, installs it on the emulator, and then launches it.
26
+
27
+
3. View the App on the Emulator:
28
+
29
+
* Once the app is installed, it automatically opens on the emulator screen, allowing you to interact with it as if it were running on a real device.
28
30
29
-
Once the application is started, click the Load Image button. It will load a randomlyselected image. Then, click Run Inference to recognize the digit. The application will display the predicted label and the inference time as shown below:
31
+
Once the application starts, click the **Load Image** button. It loads a randomly-selected image. Then, click **Run Inference** to recognize the digit. The application displays the predicted label and the inference time as shown below:
The above code snippet downloads the MNIST dataset, transforms the images into tensors, and sets up data loaders for training and testing. Specifically, the `datasets.MNIST` function is used to download the MNIST dataset, with `train=True` indicating training data and `train=False` indicating test data. The `transform=transforms.ToTensor()` argument converts each image in the dataset into a PyTorch tensor, which is necessary for model training and evaluation.
45
+
Using this code enables you to:
46
46
47
-
The DataLoader wraps the datasets and allows efficient loading of data in batches. It handles data shuffling, batching, and parallel loading. Here, the train_dataloader and test_dataloader are created with a batch_size of 32, meaning they will load 32 images per batch during training and testing.
47
+
* Download the MNIST dataset.
48
+
* Transform the images into tensors.
49
+
* Set up data loaders for training and testing.
50
+
51
+
Specifically, the `datasets.MNIST` function downloads the MNIST dataset, with `train=True` indicating training data and `train=False` indicating test data. The `transform=transforms.ToTensor()` argument converts each image in the dataset into a PyTorch tensor, which is necessary for model training and evaluation.
52
+
53
+
The DataLoader wraps the datasets and enables efficient loading of data in batches. It handles data shuffling, batching, and parallel loading. Here, the train_dataloader and test_dataloader are created with a batch_size of 32, meaning they will load 32 images per batch during training and testing.
48
54
49
55
This setup prepares the training and test datasets for use in a machine learning model, enabling efficient data handling and model training in PyTorch.
50
56
@@ -54,19 +60,21 @@ To run the above code, you will need to install certifi package:
54
60
pip install certifi
55
61
```
56
62
57
-
The certifi Python package provides the Mozilla root certificates, which are essential for ensuring the SSL connections are secure. If you’re using macOS, you may also need to install the certificates by running:
63
+
The certifi Python package provides the Mozilla root certificates, which are essential for ensuring the SSL connections are secure. If you’re using macOS, you might also need to install the certificates by running:
The first method, `train_loop`, uses the backpropagation algorithm to optimize the trainable parameters and minimize the prediction error of the neural network. The second method, `test_loop`, calculates the neural network error using the test images and displays the accuracy and loss values.
122
+
The first method, `train_loop`, uses the backpropagation algorithm to optimize the trainable parameters and minimize the prediction error rate of the neural network. The second method, `test_loop`, calculates the neural network error rate using the test images, and displays the accuracy and loss values.
115
123
116
124
You can now invoke these methods to train and evaluate the model using 10 epochs.
117
125
@@ -124,9 +132,9 @@ for t in range(epochs):
124
132
test_loop(test_dataloader, model, loss_fn)
125
133
```
126
134
127
-
After running the code, you see the following output showing the training progress.
135
+
After running the code, you see the following output showing the training progress, as displayed in Figure 2.
Once the training is complete, you see output similar to:
132
140
@@ -139,13 +147,13 @@ The output shows the model achieved around 95% accuracy.
139
147
140
148
# Save the model
141
149
142
-
Once the model is trained, you can save it. There are various approaches for this. In PyTorch, you can save both the model’s structure and its weights to the same file using the `torch.save()` function. Alternatively, you can save only the weights (parameters) of the model, not the model architecture itself. This requires you to have the model’s architecture defined separately when loading. To save the model weights, you can use the following command:
150
+
Once the model is trained, you can save it. There are various approaches for this. In PyTorch, you can save both the model’s structure and its weights to the same file using the `torch.save()` function. Alternatively, you can save only the weights of the model, not the model architecture itself. This requires you to have the model’s architecture defined separately when loading. To save the model weights, you can use the following command:
However, PyTorch does not save the definition of the class itself. When you load the model using `torch.load()`, PyTorch needs to know the class definition to recreate the model object.
156
+
However, PyTorch does not save the definition of the class itself. When you load the model using `torch.load()`, PyTorch requires the class definition to recreate the model object.
149
157
150
158
Therefore, when you later want to use the saved model for inference, you will need to provide the definition of the model class.
The above commands set the model to evaluation mode, trace the model, and save it. Tracing is useful for converting models with static computation graphs to TorchScript, making them portable and independent of the original class definition.
175
+
The above commands perform the following tasks:
176
+
177
+
* They set the model to evaluation mode.
178
+
* They trace the model.
179
+
* They save it.
180
+
181
+
Tracing is useful for converting models with static computation graphs to TorchScript, making them flexible and independent of the original class definition.
168
182
169
183
Setting the model to evaluation mode before tracing is important for several reasons:
170
184
171
-
1.Behavior of Layers like Dropout and BatchNorm:
172
-
* Dropout. During training, dropout randomly zeroes out some of the activations to prevent overfitting. During evaluation dropout is turned off, and all activations are used.
185
+
1.Behavior of Layers like Dropout and BatchNorm:
186
+
* Dropout. During training, dropout randomly zeroes out some of the activations to prevent overfitting. During evaluation, dropout is turned off, and all activations are used.
173
187
* BatchNorm. During training, Batch Normalization layers use batch statistics to normalize the input. During evaluation, they use running averages calculated during training.
174
188
175
-
2. Consistent Inference Behavior. By setting the model to eval mode, you ensure that the traced model will behave consistently during inference, as it will not use dropout or batch statistics that are inappropriate for inference.
189
+
2. Consistent Inference Behavior. By setting the model to eval mode, you ensure that the traced model behaves consistently during inference, as it does not use dropout or batch statistics that are inappropriate for inference.
176
190
177
-
3. Correct Tracing. Tracing captures the operations performed by the model using a given input. If the model is in training mode, the traced graph may include operations related to dropout and batch normalization updates. These operations can affect the correctness and performance of the model during inference.
191
+
3. Correct Tracing. Tracing captures the operations performed by the model using a given input. If the model is in training mode, the traced graph might include operations related to dropout and batch normalization updates. These operations can affect the correctness and performance of the model during inference.
178
192
179
193
In the next step, you will use the saved model for ML inference.
0 commit comments