Skip to content

Commit fb95d89

Browse files
DboyqiaoLu Teng
andauthored
[DOC] Add document about accelerated JAX on Intel GPU (#350)
Co-authored-by: Lu Teng <[email protected]>
1 parent c191336 commit fb95d89

File tree

3 files changed

+131
-0
lines changed

3 files changed

+131
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ pip install -r test/requirements.txt
8585
```
8686
Check [test/requirements.txt](test/requirements.txt) for more details.
8787

88+
The following table tracks intel-extension-for-openxla versions and compatible versions of jax and jaxlib. The compatibility between jax and jaxlib is maintained through JAX. This version restriction will be relaxed over time as the plugin API matures.
89+
|**intel-extension-for-openxla**|**jaxlib**|**jax**|
90+
|:-:|:-:|:-:|
91+
| 0.4.0 | 0.4.26 | >= 0.4.26, <= 0.4.27|
92+
| 0.3.0 | 0.4.24 | >= 0.4.24, <= 0.4.27|
93+
| 0.2.1 | 0.4.20 | >= 0.4.20, <= 0.4.26|
94+
| 0.2.0 | 0.4.20 | >= 0.4.20, <= 0.4.26|
95+
| 0.1.0 | 0.4.13 | >= 0.4.13, <= 0.4.14|
96+
8897
## 3. Install
8998

9099
### Install via PyPI wheel

docs/acc_jax.md

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Accelerated JAX on Intel GPU
2+
3+
## Intel® Extension for OpenXLA* plug-in
4+
Intel® Extension for OpenXLA includes PJRT plugin implementation, which seamlessly runs JAX models on Intel GPU. The PJRT API simplified the integration, which allowed the Intel GPU plugin to be developed separately and quickly integrated into JAX. Refer to [OpenXLA PJRT Plugin RFC](https://github.com/openxla/community/blob/main/rfcs/20230123-pjrt-plugin.md) for more details.
5+
6+
## Requirements
7+
8+
### Hardware Requirements
9+
10+
Verified Hardware Platforms:
11+
12+
* Intel® Data Center GPU Max Series, Driver Version: [803](https://dgpu-docs.intel.com/releases/LTS_803.63_20240617.html) ([Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-max-series))
13+
14+
* Intel® Data Center GPU Flex Series 170, Driver Version: [803](https://dgpu-docs.intel.com/releases/LTS_803.63_20240617.html) ([Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-flex-series))
15+
16+
### Software Requirements
17+
18+
* Ubuntu 22.04 (64-bit)
19+
* Intel® Data Center GPU Flex Series
20+
* Ubuntu 22.04, SUSE Linux Enterprise Server(SLES) 15 SP4
21+
* Intel® Data Center GPU Max Series
22+
* Intel® oneAPI Base Toolkit 2024.1 ([Installation Guides](https://github.com/intel/intel-extension-for-openxla/?tab=readme-ov-file#install-oneapi-base-toolkit-packages))
23+
* Jax/Jaxlib 0.4.26
24+
* Python 3.9-3.12
25+
* pip 19.0 or later (requires manylinux2014 support)
26+
27+
## Install
28+
The following table tracks intel-extension-for-openxla versions and compatible versions of jax and jaxlib. The compatibility between jax and jaxlib is maintained through JAX. This version restriction will be relaxed over time as the plugin API matures.
29+
|**intel-extension-for-openxla**|**jaxlib**|**jax**|
30+
|:-:|:-:|:-:|
31+
| 0.4.0 | 0.4.26 | >= 0.4.26, <= 0.4.27|
32+
| 0.3.0 | 0.4.24 | >= 0.4.24, <= 0.4.27|
33+
| 0.2.1 | 0.4.20 | >= 0.4.20, <= 0.4.26|
34+
| 0.2.0 | 0.4.20 | >= 0.4.20, <= 0.4.26|
35+
| 0.1.0 | 0.4.13 | >= 0.4.13, <= 0.4.14|
36+
37+
[conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) is recommanded as the virtual running environment.
38+
```
39+
conda create -n jax-ioex python=3.10
40+
conda activate jax-ioex
41+
pip install -U pip
42+
pip install jax==0.4.26 jaxlib==0.4.26
43+
pip install intel-extension-for-openxla
44+
```
45+
46+
## Verify
47+
```
48+
python -c "import jax; print(jax.devices())"
49+
```
50+
Reference result:
51+
```
52+
[xpu(id=0), xpu(id=1)]
53+
```
54+
55+
## Example - Run Stable Diffusion Inference
56+
57+
### Install miniforge
58+
```
59+
curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
60+
bash Miniforge3-$(uname)-$(uname -m).sh
61+
```
62+
63+
### Setup environment
64+
```
65+
conda create -n stable-diffusion python=3.10
66+
conda activate stable-diffusion
67+
pip install -U pip
68+
pip install jax==0.4.26 jaxlib==0.4.26 flax==0.8.2
69+
pip install intel-extension-for-openxla
70+
pip install transformers==4.38 diffusers==0.26.3 datasets==2.12.0 msgpack==1.0.7
71+
```
72+
Source OneAPI env
73+
```
74+
source /opt/intel/oneapi/compiler/2024.2/env/vars.sh
75+
source /opt/intel/oneapi/mkl/2024.2/env/vars.sh
76+
```
77+
**NOTE**: The path of OneAPI env script is based on the OneAPI installed path.
78+
79+
80+
### Run Demo (Stable Diffusion Inference)
81+
82+
Go to [example/stable_diffusion](../example/stable_diffusion/README.md) for detail about this demo.
83+
84+
| **Command** | **Model** | **Output Image Resolution** |
85+
| :--- | :---: | :---: |
86+
| ```python jax_stable.py``` | [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) | 512x512 |
87+
| ```python jax_stable.py -m stabilityai/stable-diffusion-2``` | [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) | 768x768 |
88+
| ```python jax_stable.py -m stabilityai/stable-diffusion-2-1``` | [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) | 768x768 |
89+
90+
### Expected result:
91+
```
92+
Average Latency per image is: x.xxx s
93+
Average Throughput per second is: x.xxx steps
94+
```
95+
96+
## Support
97+
To submit questions, feature requests, and bug reports about the intel-extension-for-openxla plugin, visit the [GitHub intel-extension-for-openxla issues](https://github.com/intel/intel-extension-for-openxla/issues) page. You can also view [GitHub JAX Issues](https://github.com/google/jax/issues) with the label "Intel GPU plugin".
98+
99+
## FAQ
100+
101+
1. If there is an error 'No visible XPU devices', print `jax.local_devices()` to check which device is running. Set `export OCL_ICD_ENABLE_TRACE=1` to check if there are driver error messages. The following code opens more debug log for JAX app.
102+
103+
```python
104+
import logging
105+
logging.basicConfig(level = logging.DEBUG)
106+
```
107+
108+
2. If there is an error 'version GLIBCXX_3.4.30' not found, upgrade libstdc++ to the latest, for example for conda
109+
110+
```bash
111+
conda install libstdcxx-ng==12.2.0 -c conda-forge
112+
```

xla/tools/pip_package/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ Intel® Extension for OpenXLA* includes PJRT plugin implementation, which seamle
1010

1111
## Installation
1212

13+
The following table tracks intel-extension-for-openxla versions and compatible versions of jax, jaxlib.
14+
| **intel-extension-for-openxla** | **jaxlib** | **jax** |
15+
|:-:|:-:|:-:|
16+
| 0.4.0 | 0.4.26 | >= 0.4.26, <= 0.4.27|
17+
| 0.3.0 | 0.4.24 | >= 0.4.24, <= 0.4.27|
18+
| 0.2.1 | 0.4.20 | >= 0.4.20, <= 0.4.26|
19+
| 0.2.0 | 0.4.20 | >= 0.4.20, <= 0.4.26|
20+
| 0.1.0 | 0.4.13 | >= 0.4.13, <= 0.4.14|
21+
22+
1323
```
1424
pip install --upgrade intel-extension-for-openxla
1525
```

0 commit comments

Comments
 (0)