---
title: "Using ONNX Runtime"

weight: 3

layout: "learningpathall"
---

## Objective
Next, you will implement Python code that accomplishes the following tasks:
* Downloads a pre-trained ONNX model trained on the MNIST dataset, along with the MNIST dataset itself, which is widely used for benchmarking machine learning models.
* Runs inference with the pre-trained ONNX model on test images of handwritten digits from the MNIST dataset.
* Measures the performance of the inference process, providing insight into the efficiency and speed of the neural network model on your specific system architecture.

This practical demonstration illustrates the end-to-end workflow of deploying and evaluating ONNX-formatted machine learning models.

## Implementation

### Model
Create a file named main.py. At the beginning of this file, include the following import statements:

```Python
import onnxruntime as ort
import numpy as np
import matplotlib.pyplot as plt
import wget, time, os, urllib.parse
import torchvision
import torchvision.transforms as transforms
```

These statements import the necessary Python libraries:
* onnxruntime - enables running inference with ONNX models.
* numpy - facilitates numerical computations and handling of arrays.
* matplotlib - used for visualizing results such as classification outputs.
* wget, urllib, os, and time - provide utilities for downloading files, interacting with the file system, and measuring execution time.
* torchvision - allows easy access to datasets like MNIST.

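Before moving on, you can optionally verify that onnxruntime imports correctly and see which execution providers your build exposes. This is just a quick sanity check; the provider list varies by platform and installation:

```Python
import onnxruntime as ort

# Print the installed onnxruntime version and the available execution
# providers (for example, ['CPUExecutionProvider'] on a CPU-only install).
print(ort.__version__)
print(ort.get_available_providers())
```
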
Next, add the following function immediately below the import statements in your main.py file:

```Python
def download_model(model_name):
    if not os.path.exists(model_name):
        base_url = 'https://github.com/dawidborycki/ONNX.WoA/raw/refs/heads/main/models/'
        url = urllib.parse.urljoin(base_url, model_name)
        wget.download(url)
```

This function, download_model, accepts one parameter, model_name. It first checks whether a file with this name already exists in your local directory. If the file does not exist, it downloads the specified ONNX model file from the given GitHub repository URL. This automated check ensures that you won't repeatedly download the model unnecessarily.
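
If you prefer to avoid the wget dependency, the same download can be done with the standard library alone. The sketch below is a hypothetical equivalent (the name download_model_stdlib is illustrative, not part of the tutorial's code) using urllib.request.urlretrieve:

```Python
import os
import urllib.parse
import urllib.request

def download_model_stdlib(model_name):
    # Hypothetical stand-in for download_model that uses only the standard library
    if not os.path.exists(model_name):
        base_url = 'https://github.com/dawidborycki/ONNX.WoA/raw/refs/heads/main/models/'
        url = urllib.parse.urljoin(base_url, model_name)
        urllib.request.urlretrieve(url, model_name)
```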

### Inference
Next, you will implement a Python function that performs inference. Add the following code to your main.py file below the previously defined download_model function:

```Python
def onnx_predict(onnx_session, input_name, output_name,
                 test_images, test_labels, image_index, show_results):

    # Reshape the 28x28 image to the (1, 1, 28, 28) input expected by the model
    test_image = np.expand_dims(test_images[image_index], [0, 1])

    onnx_pred = onnx_session.run([output_name], {input_name: test_image.astype('float32')})

    # Select the digit class with the highest score
    predicted_label = np.argmax(np.array(onnx_pred))
    actual_label = test_labels[image_index]

    if show_results:
        plt.figure()
        plt.xticks([])
        plt.yticks([])
        plt.imshow(test_images[image_index], cmap=plt.cm.binary)

        plt.title('Actual: %s, predicted: %s'
                  % (actual_label, predicted_label), fontsize=22)
        plt.show()

    return predicted_label, actual_label
```

The onnx_predict function prepares a single test image from the dataset by reshaping it to match the input shape expected by the ONNX model, which is (1, 1, 28, 28). This reshaping is achieved using NumPy's expand_dims function. Next, the function performs inference using ONNX Runtime (onnx_session.run). The inference results are scores for each digit class, and the function uses np.argmax to select the digit class with the highest score, returning it as the predicted label. Optionally, the function displays the image along with its actual and predicted labels.
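
To see what this reshaping and selection do in isolation, here is a small standalone sketch (the score values are made up for illustration):

```Python
import numpy as np

# A 28x28 grayscale image gains a batch axis and a channel axis:
image = np.zeros((28, 28))
batched = np.expand_dims(image, [0, 1])
print(batched.shape)      # (1, 1, 28, 28)

# np.argmax returns the index of the highest score. With these
# hypothetical scores for digits 0-9, the predicted label is 3.
scores = np.array([0.1, 0.0, 0.2, 9.5, 0.3, 0.1, 0.0, 0.2, 0.4, 0.1])
print(np.argmax(scores))  # 3
```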

### Performance measurements
Next, add the following performance-measuring function below onnx_predict in your main.py file:

```Python
def measure_performance(onnx_session, input_name, output_name,
                        test_images, test_labels, execution_count):

    start = time.time()

    # Pre-select random image indices (randint's upper bound is exclusive)
    image_indices = np.random.randint(0, test_images.shape[0], execution_count)

    for i in range(execution_count):
        onnx_predict(onnx_session, input_name, output_name,
                     test_images, test_labels, image_indices[i], False)

    computation_time = time.time() - start

    print('Computation time: %.3f ms' % (computation_time*1000))
```

This measure_performance function assesses the inference speed by repeatedly invoking the onnx_predict function. It measures the total computation time (in milliseconds) required for the specified number of inference executions (execution_count) and outputs this measurement to the console.
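
If you also want the average latency per inference, a small variation of the same idea could look like the sketch below. It reuses onnx_predict, generates the indices before timing so only inference is measured, and uses time.perf_counter, which is better suited to interval timing than time.time:

```Python
def measure_average_latency(onnx_session, input_name, output_name,
                            test_images, test_labels, execution_count):
    # Generate the random indices before timing so only inference is measured
    image_indices = np.random.randint(0, test_images.shape[0], execution_count)

    start = time.perf_counter()
    for i in range(execution_count):
        onnx_predict(onnx_session, input_name, output_name,
                     test_images, test_labels, image_indices[i], False)
    elapsed = time.perf_counter() - start

    print('Average latency: %.3f ms' % (elapsed / execution_count * 1000))
```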

### Putting Everything Together
Finally, integrate all previously defined functions by adding these statements at the bottom of your main.py file:
```Python
if __name__ == "__main__":
    # Download and prepare the model
    model_name = 'mnist-12.onnx'
    download_model(model_name)

    # Set up the ONNX inference session
    onnx_session = ort.InferenceSession(model_name)

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name

    # Load the MNIST test split using torchvision
    transform = transforms.Compose([transforms.ToTensor()])
    mnist_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                               download=True, transform=transform)

    test_images = mnist_dataset.data.numpy()
    test_labels = mnist_dataset.targets.numpy()

    # Normalize pixel values to the [0, 1] range
    test_images = test_images / 255.0

    # Perform a single prediction and display the result
    image_index = np.random.randint(0, test_images.shape[0])
    onnx_predict(onnx_session, input_name, output_name,
                 test_images, test_labels, image_index, True)

    # Measure inference performance
    measure_performance(onnx_session, input_name, output_name,
                        test_images, test_labels, execution_count=1000)
```

This script first downloads the model (mnist-12.onnx) if it is not already present, then initializes an ONNX inference session with it. It retrieves the model's input and output names, loads the MNIST test set, runs a sample inference showing visual results, and finally measures the performance of the inference operation over multiple runs.
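
Before running the full script, you can optionally confirm what the model expects by printing its input and output metadata. A minimal sketch, assuming mnist-12.onnx has already been downloaded to the current directory:

```Python
import onnxruntime as ort

session = ort.InferenceSession('mnist-12.onnx')

# Each input and output exposes a name, an element type, and a shape;
# for this model the input shape should be [1, 1, 28, 28].
for tensor in session.get_inputs():
    print('input :', tensor.name, tensor.type, tensor.shape)
for tensor in session.get_outputs():
    print('output:', tensor.name, tensor.type, tensor.shape)
```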

## Summary
In this section, you implemented Python code to download a pre-trained ONNX model and the MNIST dataset, perform inference to recognize handwritten digits, and measure inference performance. In the next step, you will install all required dependencies and run the code.