AI-and-Analytics/Getting-Started-Samples/IntelJAX_GettingStarted/README.md
4 additions & 3 deletions
@@ -11,7 +11,7 @@ The `JAX Getting Started` sample demonstrates how to train a JAX model and run i
JAX is a high-performance numerical computing library that enables automatic differentiation. It provides features like just-in-time compilation and efficient parallelization for machine learning and scientific computing tasks.
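The two JAX features named above can be illustrated with a minimal sketch (illustrative only, not part of the sample): `jax.grad` derives a gradient function by automatic differentiation, and `jax.jit` compiles it just-in-time with XLA.

```python
import jax
import jax.numpy as jnp

# A toy scalar function: f(x) = sum(x^2).
def f(x):
    return jnp.sum(x ** 2)

# jax.grad builds the gradient function via automatic differentiation;
# jax.jit compiles it with XLA for fast repeated calls.
grad_f = jax.jit(jax.grad(f))

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))  # gradient of sum(x^2) is 2 * x -> [2. 4. 6.]
```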
-This sample code shows how to get started with JAX in CPU. The sample code defines a simple neural network that trains on the MNIST dataset using JAX for parallel computations across multiple CPU cores. The network trains over multiple epochs, evaluates accuracy, and adjusts parameters using stochastic gradient descent across devices.
+This sample code shows how to get started with JAX on CPU. The sample code defines a simple neural network that trains on the MNIST dataset using JAX for parallel computations across multiple CPU cores. The network trains over multiple epochs, evaluates accuracy, and adjusts parameters using stochastic gradient descent across devices.
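The cross-device parallelism described above can be sketched with `jax.pmap` (a minimal illustration, not the sample's code; note that a default CPU setup usually exposes a single device unless extra XLA flags configure more):

```python
import jax
import jax.numpy as jnp

# Number of visible devices; the sample relies on multiple CPU devices
# for SPMD training, but this sketch works with any count, including 1.
n_devices = jax.local_device_count()

# jax.pmap replicates the function across devices; each device receives
# one slice of the input's leading axis.
parallel_square = jax.pmap(lambda x: x ** 2)

# Shard the batch: the leading dimension must equal the device count.
batch = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
print(parallel_square(batch))
```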
## Prerequisites
@@ -25,7 +25,8 @@ This sample code shows how to get started with JAX in CPU. The sample code defin
## Key Implementation Details
-The example implementation involves a python file 'spmd_mnist_classifier_fromscratch.py' under the examples directory from the jax repo [(https://github.com/google/jax/)].
+The getting-started sample code uses the python file 'spmd_mnist_classifier_fromscratch.py' under the examples directory in the
+[jax repository](https://github.com/google/jax/).
It implements a simple neural network's training and inference for MNIST images. The images are downloaded to a temporary directory the first time the example is run.
-**init_random_params** initializes the neural network weights and biases for each layer.
-**predict** computes the forward pass of the network, applying weights, biases, and activations to inputs.
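The two helpers above can be sketched roughly as follows (a simplified reconstruction for illustration; the authoritative code and exact signatures are in 'spmd_mnist_classifier_fromscratch.py' in the jax repository):

```python
import numpy.random as npr
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    # One (weights, biases) pair per layer, drawn from a scaled normal.
    return [(scale * rng.randn(m, n), scale * rng.randn(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

def predict(params, inputs):
    # Forward pass: tanh hidden layers, log-softmax output over classes.
    activations = inputs
    for w, b in params[:-1]:
        activations = jnp.tanh(jnp.dot(activations, w) + b)
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)
```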
@@ -84,7 +85,7 @@ Go to the section which corresponds to the installation method chosen in [AI Too
### Docker
AI Tools Docker images already have Get Started samples pre-installed. Refer to [Working with Preset Containers](https://github.com/intel/ai-containers/tree/main/preset) to learn how to run the Docker container and the samples.
## Example Output
-1.With the initial run, you should see results similar to the following:
+1.When the program is run, you should see results similar to the following:
```
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/