Commit a694251

Merge pull request #2507 from intel-ai-tce/jax_getting_started
Jax getting started
2 parents 26333fa + ee3a621 commit a694251

File tree

6 files changed: +430 −0 lines changed
AI-and-Analytics/Getting-Started-Samples/IntelJAX_GettingStarted/.gitkeep

Whitespace-only changes.
Lines changed: 7 additions & 0 deletions

Copyright Intel Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Lines changed: 140 additions & 0 deletions
# `JAX Getting Started` Sample

The `JAX Getting Started` sample demonstrates how to train a JAX model and run inference on Intel® hardware.

| Property | Description
|:--- |:---
| Category | Getting Started Sample
| What you will learn | How to start using JAX* on Intel® hardware.
| Time to complete | 10 minutes
## Purpose

JAX is a high-performance numerical computing library that enables automatic differentiation. It provides features like just-in-time compilation and efficient parallelization for machine learning and scientific computing tasks.

This sample shows how to get started with JAX on CPU. The sample code defines a simple neural network that trains on the MNIST dataset, using JAX to parallelize computations across multiple CPU cores. The network trains over multiple epochs, evaluates accuracy, and adjusts parameters using stochastic gradient descent across devices.
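To illustrate the two JAX features mentioned above, here is a minimal, self-contained sketch (not part of the sample itself) combining just-in-time compilation with automatic differentiation:

```python
import jax
import jax.numpy as jnp

@jax.jit  # compile the function with XLA on first call
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Automatic differentiation: the gradient of sum(x**2) is 2*x.
grad_fn = jax.grad(sum_of_squares)
g = grad_fn(jnp.array([1.0, 2.0, 3.0]))
print(g)  # [2. 4. 6.]
```

Subsequent calls with arrays of the same shape and dtype reuse the compiled code, which is where JAX gets much of its speed.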
## Prerequisites

| Optimized for | Description
|:--- |:---
| OS | Ubuntu* 22.04 and newer
| Hardware | Intel® Xeon® Scalable processor family
| Software | JAX

> **Note**: AI and Analytics samples are validated on the AI Tools Offline Installer. For the full list of validated platforms, refer to [Platform Validation](https://github.com/oneapi-src/oneAPI-samples/tree/master?tab=readme-ov-file#platform-validation).
## Key Implementation Details

The getting-started sample code uses the Python script `spmd_mnist_classifier_fromscratch.py` from the `examples` directory of the [jax repository](https://github.com/google/jax/). It implements training and inference of a simple neural network on MNIST images. The images are downloaded to a temporary directory the first time the example is run.

- **init_random_params** initializes the neural network weights and biases for each layer.
- **predict** computes the forward pass of the network, applying weights, biases, and activations to inputs.
- **loss** calculates the cross-entropy loss between predictions and target labels.
- **spmd_update** performs parallel gradient updates across multiple devices using JAX's `pmap` and `lax.psum`.
- **accuracy** computes the accuracy of the model by predicting the class of each input in the batch and comparing it to the true target class. It uses `jnp.argmax` to find the predicted class and then computes the mean of correct predictions.
- **data_stream** generates batches of shuffled training data. It reshapes the data so that it can be split across multiple cores, ensuring that the batch size is divisible by the number of cores for parallel processing.
- **training loop** trains the model for a set number of epochs, updating parameters and printing training/test accuracy after each epoch. The parameters are replicated across devices and updated in parallel using `spmd_update`. After each epoch, the model's accuracy is evaluated on both training and test data using `accuracy`.
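The SPMD update pattern described above can be sketched as follows. This is a minimal illustration of the `pmap`/`lax.psum` idiom, not the sample's exact code: the toy two-layer network and squared-error loss stand in for the sample's MNIST classifier and cross-entropy loss, and the names `predict`, `loss`, and `spmd_update` simply mirror the description.

```python
import functools
import jax
import jax.numpy as jnp
from jax import grad, lax, pmap

def predict(params, x):
    # Forward pass of a toy two-layer network (tanh hidden layer).
    (w1, b1), (w2, b2) = params
    h = jnp.tanh(x @ w1 + b1)
    return h @ w2 + b2

def loss(params, x, y):
    # Squared error stands in for the sample's cross-entropy loss.
    return jnp.mean((predict(params, x) - y) ** 2)

@functools.partial(pmap, axis_name="batch")
def spmd_update(params, x, y):
    step_size = 0.1
    grads = grad(loss)(params, x, y)
    # Combine per-device gradients with a cross-device sum (lax.psum),
    # so every replica applies the same SGD step.
    grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
    return jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

n_dev = jax.local_device_count()
params = ((jnp.zeros((3, 4)), jnp.zeros(4)), (jnp.zeros((4, 1)), jnp.zeros(1)))
# Replicate parameters and shard the batch along a leading device axis.
rep_params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)
x = jnp.ones((n_dev, 8, 3))
y = jnp.ones((n_dev, 8, 1))
new_params = spmd_update(rep_params, x, y)
```

On a machine with a single CPU device this runs as-is; with more devices, the leading axis of every `pmap` input must equal `jax.local_device_count()`, which is exactly why the sample's `data_stream` reshapes batches to be divisible by the number of cores.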
## Environment Setup

You will need to download and install the following toolkits, tools, and components to use the sample.

**1. Get Intel® AI Tools**

Required AI Tools: JAX
<br>If you have not already, select and install these tools via the [AI Tools Selector](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html). AI and Analytics samples are validated on the AI Tools Offline Installer, so selecting the Offline Installer option is recommended. Please see the [supported versions](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html).

>**Note**: If the Docker option is chosen in the AI Tools Selector, refer to [Working with Preset Containers](https://github.com/intel/ai-containers/tree/main/preset) to learn how to run the Docker images and samples.

**2. (Offline Installer) Activate the AI Tools bundle base environment**

If the default path is used during the installation of AI Tools:
```
source $HOME/intel/oneapi/intelpython/bin/activate
```
If a non-default path is used:
```
source <custom_path>/bin/activate
```

**3. (Offline Installer) Activate the relevant Conda environment**

For a system with an Intel CPU:
```
conda activate jax
```

**4. Clone the GitHub repository**
```
git clone https://github.com/google/jax.git
cd jax
export PYTHONPATH=$PYTHONPATH:$(pwd)
```
## Run the Sample

>**Note**: Before running the sample, make sure Environment Setup is completed.

Go to the section that corresponds to the installation method chosen in the [AI Tools Selector](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html) for relevant instructions:
* [AI Tools Offline Installer (Validated)/Conda/PIP](#ai-tools-offline-installer-validatedcondapip)
* [Docker](#docker)

### AI Tools Offline Installer (Validated)/Conda/PIP
```
python examples/spmd_mnist_classifier_fromscratch.py
```

### Docker
AI Tools Docker images already have Getting Started samples pre-installed. Refer to [Working with Preset Containers](https://github.com/intel/ai-containers/tree/main/preset) to learn how to run the Docker images and samples.
## Example Output

1. When the program is run, you should see results similar to the following:

```
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/
Epoch 0 in 2.71 sec
Training set accuracy 0.7381166815757751
Test set accuracy 0.7516999840736389
Epoch 1 in 2.35 sec
Training set accuracy 0.81454998254776
Test set accuracy 0.8277999758720398
Epoch 2 in 2.33 sec
Training set accuracy 0.8448166847229004
Test set accuracy 0.8568999767303467
Epoch 3 in 2.34 sec
Training set accuracy 0.8626833558082581
Test set accuracy 0.8715999722480774
Epoch 4 in 2.30 sec
Training set accuracy 0.8752999901771545
Test set accuracy 0.8816999793052673
Epoch 5 in 2.33 sec
Training set accuracy 0.8839333653450012
Test set accuracy 0.8899999856948853
Epoch 6 in 2.37 sec
Training set accuracy 0.8908833265304565
Test set accuracy 0.8944999575614929
Epoch 7 in 2.31 sec
Training set accuracy 0.8964999914169312
Test set accuracy 0.8986999988555908
Epoch 8 in 2.28 sec
Training set accuracy 0.9016000032424927
Test set accuracy 0.9034000039100647
Epoch 9 in 2.31 sec
Training set accuracy 0.9060333371162415
Test set accuracy 0.9059999585151672
```

2. Troubleshooting

If you receive an error message, troubleshoot the problem using the **Diagnostics Utility for Intel® oneAPI Toolkits**. The diagnostic utility provides configuration and system checks to help find missing dependencies, permissions errors, and other issues. See the *[Diagnostics Utility for Intel® oneAPI Toolkits User Guide](https://www.intel.com/content/www/us/en/develop/documentation/diagnostic-utility-user-guide/top.html)* for more information on using the utility.
## License

Code samples are licensed under the MIT license. See
[License.txt](https://github.com/oneapi-src/oneAPI-samples/blob/master/License.txt)
for details.

Third party program Licenses can be found here:
[third-party-programs.txt](https://github.com/oneapi-src/oneAPI-samples/blob/master/third-party-programs.txt)

*Other names and brands may be claimed as the property of others. [Trademarks](https://www.intel.com/content/www/us/en/legal/trademarks.html)
Lines changed: 6 additions & 0 deletions

source $HOME/intel/oneapi/intelpython/bin/activate
conda activate jax
git clone https://github.com/google/jax.git
cd jax
export PYTHONPATH=$PYTHONPATH:$(pwd)
python examples/spmd_mnist_classifier_fromscratch.py
Lines changed: 24 additions & 0 deletions

{
  "guid": "9A6A140B-FBD0-4CB2-849A-9CAF15A6F3B1",
  "name": "Getting Started example for JAX CPU",
  "categories": ["Toolkit/oneAPI AI And Analytics/Getting Started"],
  "description": "This sample illustrates how to train a JAX model and run inference",
  "builder": ["cli"],
  "languages": [{
    "python": {}
  }],
  "os": ["linux"],
  "targetDevice": ["CPU"],
  "ciTests": {
    "linux": [{
      "id": "JAX CPU example",
      "steps": [
        "git clone https://github.com/google/jax.git",
        "cd jax",
        "conda activate jax",
        "python examples/spmd_mnist_classifier_fromscratch.py"
      ]
    }]
  },
  "expertise": "Getting Started"
}
