|
| 1 | +# `JAX Getting Started` Sample |
| 2 | + |
| 3 | +The `JAX Getting Started` sample demonstrates how to train a JAX model and run inference on Intel® hardware. |
| 4 | +| Property | Description |
| 5 | +|:--- |:--- |
| 6 | +| Category | Get Start Sample |
| 7 | +| What you will learn | How to start using JAX* on Intel® hardware. |
| 8 | +| Time to complete | 10 minutes |
| 9 | + |
| 10 | +## Purpose |
| 11 | + |
| 12 | +JAX is a high-performance numerical computing library that enables automatic differentiation. It provides features like just-in-time compilation and efficient parallelization for machine learning and scientific computing tasks. |
| 13 | + |
| 14 | +This sample code shows how to get started with JAX on CPU. The sample code defines a simple neural network that trains on the MNIST dataset using JAX for parallel computations across multiple CPU cores. The network trains over multiple epochs, evaluates accuracy, and adjusts parameters using stochastic gradient descent across devices. |
| 15 | + |
| 16 | +## Prerequisites |
| 17 | + |
| 18 | +| Optimized for | Description |
| 19 | +|:--- |:--- |
| 20 | +| OS | Ubuntu* 22.0.4 and newer |
| 21 | +| Hardware | Intel® Xeon® Scalable processor family |
| 22 | +| Software | JAX |
| 23 | + |
| 24 | +> **Note**: AI and Analytics samples are validated on AI Tools Offline Installer. For the full list of validated platforms refer to [Platform Validation](https://github.com/oneapi-src/oneAPI-samples/tree/master?tab=readme-ov-file#platform-validation). |
| 25 | +
|
| 26 | +## Key Implementation Details |
| 27 | + |
| 28 | +The getting-started sample code uses the python file 'spmd_mnist_classifier_fromscratch.py' under the examples directory in the |
| 29 | +[jax repository](https://github.com/google/jax/). |
| 30 | +It implements a simple neural network's training and inference for mnist images. The images are downloaded to a temporary directory when the example is run first. |
| 31 | +- **init_random_params** initializes the neural network weights and biases for each layer. |
| 32 | +- **predict** computes the forward pass of the network, applying weights, biases, and activations to inputs. |
| 33 | +- **loss** calculates the cross-entropy loss between predictions and target labels. |
| 34 | +- **spmd_update** performs parallel gradient updates across multiple devices using JAX’s pmap and lax.psum. |
| 35 | +- **accuracy** computes the accuracy of the model by predicting the class of each input in the batch and comparing it to the true target class. It uses the *jnp.argmax* function to find the predicted class and then computes the mean of correct predictions. |
| 36 | +- **data_stream** function generates batches of shuffled training data. It reshapes the data so that it can be split across multiple cores, ensuring that the batch size is divisible by the number of cores for parallel processing. |
| 37 | +- **training loop** trains the model for a set number of epochs, updating parameters and printing training/test accuracy after each epoch. The parameters are replicated across devices and updated in parallel using spmd_update. After each epoch, the model’s accuracy is evaluated on both training and test data using accuracy. |
| 38 | + |
| 39 | +## Environment Setup |
| 40 | + |
| 41 | +You will need to download and install the following toolkits, tools, and components to use the sample. |
| 42 | + |
| 43 | +**1. Get Intel® AI Tools** |
| 44 | + |
| 45 | +Required AI Tools: 'JAX' |
| 46 | +<br>If you have not already, select and install these Tools via [AI Tools Selector](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html). AI and Analytics samples are validated on AI Tools Offline Installer. It is recommended to select Offline Installer option in AI Tools Selector.<br> |
| 47 | +please see the [supported versions](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html). |
| 48 | + |
| 49 | +>**Note**: If Docker option is chosen in AI Tools Selector, refer to [Working with Preset Containers](https://github.com/intel/ai-containers/tree/main/preset) to learn how to run the docker and samples. |
| 50 | +
|
| 51 | +**2. (Offline Installer) Activate the AI Tools bundle base environment** |
| 52 | + |
| 53 | +If the default path is used during the installation of AI Tools: |
| 54 | +``` |
| 55 | +source $HOME/intel/oneapi/intelpython/bin/activate |
| 56 | +``` |
| 57 | +If a non-default path is used: |
| 58 | +``` |
| 59 | +source <custom_path>/bin/activate |
| 60 | +``` |
| 61 | + |
| 62 | +**3. (Offline Installer) Activate relevant Conda environment** |
| 63 | + |
| 64 | +For the system with Intel CPU: |
| 65 | +``` |
| 66 | +conda activate jax |
| 67 | +``` |
| 68 | + |
| 69 | +**4. Clone the GitHub repository** |
| 70 | +``` |
| 71 | +git clone https://github.com/google/jax.git |
| 72 | +cd jax |
| 73 | +export PYTHONPATH=$PYTHONPATH:$(pwd) |
| 74 | +``` |
| 75 | +## Run the Sample |
| 76 | + |
| 77 | +>**Note**: Before running the sample, make sure Environment Setup is completed. |
| 78 | +Go to the section which corresponds to the installation method chosen in [AI Tools Selector](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html) to see relevant instructions: |
| 79 | +* [AI Tools Offline Installer (Validated)/Conda/PIP](#ai-tools-offline-installer-validatedcondapip) |
| 80 | +* [Docker](#docker) |
| 81 | +### AI Tools Offline Installer (Validated)/Conda/PIP |
| 82 | +``` |
| 83 | + python examples/spmd_mnist_classifier_fromscratch.py |
| 84 | +``` |
| 85 | +### Docker |
| 86 | +AI Tools Docker images already have Get Started samples pre-installed. Refer to [Working with Preset Containers](https://github.com/intel/ai-containers/tree/main/preset) to learn how to run the docker and samples. |
| 87 | +## Example Output |
| 88 | +1. When the program is run, you should see results similar to the following: |
| 89 | + |
| 90 | +``` |
| 91 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/ |
| 92 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/ |
| 93 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/ |
| 94 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/ |
| 95 | +Epoch 0 in 2.71 sec |
| 96 | +Training set accuracy 0.7381166815757751 |
| 97 | +Test set accuracy 0.7516999840736389 |
| 98 | +Epoch 1 in 2.35 sec |
| 99 | +Training set accuracy 0.81454998254776 |
| 100 | +Test set accuracy 0.8277999758720398 |
| 101 | +Epoch 2 in 2.33 sec |
| 102 | +Training set accuracy 0.8448166847229004 |
| 103 | +Test set accuracy 0.8568999767303467 |
| 104 | +Epoch 3 in 2.34 sec |
| 105 | +Training set accuracy 0.8626833558082581 |
| 106 | +Test set accuracy 0.8715999722480774 |
| 107 | +Epoch 4 in 2.30 sec |
| 108 | +Training set accuracy 0.8752999901771545 |
| 109 | +Test set accuracy 0.8816999793052673 |
| 110 | +Epoch 5 in 2.33 sec |
| 111 | +Training set accuracy 0.8839333653450012 |
| 112 | +Test set accuracy 0.8899999856948853 |
| 113 | +Epoch 6 in 2.37 sec |
| 114 | +Training set accuracy 0.8908833265304565 |
| 115 | +Test set accuracy 0.8944999575614929 |
| 116 | +Epoch 7 in 2.31 sec |
| 117 | +Training set accuracy 0.8964999914169312 |
| 118 | +Test set accuracy 0.8986999988555908 |
| 119 | +Epoch 8 in 2.28 sec |
| 120 | +Training set accuracy 0.9016000032424927 |
| 121 | +Test set accuracy 0.9034000039100647 |
| 122 | +Epoch 9 in 2.31 sec |
| 123 | +Training set accuracy 0.9060333371162415 |
| 124 | +Test set accuracy 0.9059999585151672 |
| 125 | +``` |
| 126 | + |
| 127 | +2. Troubleshooting |
| 128 | + |
| 129 | + If you receive an error message, troubleshoot the problem using the **Diagnostics Utility for Intel® oneAPI Toolkits**. The diagnostic utility provides configuration and system checks to help find missing dependencies, permissions errors, and other issues. See the *[Diagnostics Utility for Intel® oneAPI Toolkits User Guide](https://www.intel.com/content/www/us/en/develop/documentation/diagnostic-utility-user-guide/top.html)* for more information on using the utility |
| 130 | + |
| 131 | +## License |
| 132 | + |
| 133 | +Code samples are licensed under the MIT license. See |
| 134 | +[License.txt](https://github.com/oneapi-src/oneAPI-samples/blob/master/License.txt) |
| 135 | +for details. |
| 136 | + |
| 137 | +Third party program Licenses can be found here: |
| 138 | +[third-party-programs.txt](https://github.com/oneapi-src/oneAPI-samples/blob/master/third-party-programs.txt) |
| 139 | + |
| 140 | +*Other names and brands may be claimed as the property of others. [Trademarks](https://www.intel.com/content/www/us/en/legal/trademarks.html) |
0 commit comments