Skip to content

Commit a673c95

Browse files
authored
Rebase with jax-v0.4.20 (#119)
1 parent 8858e32 commit a673c95

File tree

126 files changed

+7528
-19667
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

126 files changed

+7528
-19667
lines changed

.bazelrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ build:gpu --define=using_sycl=true
2424
build:gpu --define=device=gpu
2525
build:gpu --config=onednn_v3
2626
build:gpu --repo_env TF_NEED_SYCL=1
27+
build:gpu --define=tensorflow_mkldnn_contraction_kernel=0
2728

2829
# This config build with oneDNN V3 API, which is enabled by default
2930
build:onednn_v3 --copt=-DITEX_ONEDNN_3_0 --define=onednn_version=3

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Verified Hardware Platforms:
3737
* Ubuntu 22.04, Red Hat 8.6/8.8/9.2 (64-bit), SUSE Linux Enterprise Server(SLES) 15 SP4
3838
* Intel® Data Center GPU Max Series
3939
* Intel® oneAPI Base Toolkit 2023.2
40-
* Jax/Jaxlib 0.4.13
40+
* Jax/Jaxlib 0.4.20
4141
* Python 3.9-3.11
4242
* pip 19.0 or later (requires manylinux2014 support)
4343

@@ -69,7 +69,7 @@ source /opt/intel/oneapi/tbb/2021.9.0/env/vars.sh
6969
### Install Jax and Jaxlib
7070

7171
```bash
72-
pip install jax==0.4.13 jaxlib==0.4.13
72+
pip install jax==0.4.20 jaxlib==0.4.20
7373
```
7474

7575
## 3. Install

WORKSPACE

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ http_archive(
1414
name = "xla",
1515
patch_args = ["-p1"],
1616
patches = ["//third_party:openxla.patch"],
17-
sha256 = "4ec16aff3862c5a243db956ce558d7a62eb79f5e20747b0e80802a3b0d12e419",
18-
strip_prefix = "xla-12de6ec958419b57be248d0acd2d9f757e71748c",
17+
sha256 = "15235814a637f2199d2adae05c09dc8f6655737f06713108f970749bc5e87b46",
18+
strip_prefix = "xla-ca31652cdbeb6ea187589dea546ff8019274f8b2",
1919
urls = [
20-
"https://github.com/openxla/xla/archive/12de6ec958419b57be248d0acd2d9f757e71748c.tar.gz",
20+
"https://github.com/openxla/xla/archive/ca31652cdbeb6ea187589dea546ff8019274f8b2.tar.gz",
2121
],
2222
)
2323

@@ -62,4 +62,4 @@ bazel_toolchains_repositories()
6262

6363
load("//xla:workspace.bzl", "workspace")
6464

65-
workspace()
65+
workspace()

example/bert/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Please got the [main page](https://github.com/intel/intel-extension-for-openxla/
1010
### 2. Install dependency
1111
Mark `intel-extension-for-openxla` folder as \<WORKSPACE\>, then
1212
```bash
13-
pip install jax==0.4.13 jaxlib==0.4.13 flax==0.7.0
13+
pip install jax==0.4.20 jaxlib==0.4.20 flax==0.7.0
1414
cd <WORKSPACE>/example/bert
1515
pip install -r requirements.txt
1616
```

example/gptj/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Script jax_gptj.py for [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/g
55
## Prerequisites
66

77
```bash
8-
pip install jax==0.4.13 jaxlib==0.4.13 flax==0.7.0 transformers==4.27.4 diffusers==0.16.1 datasets==2.12.0
8+
pip install jax==0.4.20 jaxlib==0.4.20 flax==0.7.0 transformers==4.32 diffusers==0.16.1 datasets==2.12.0
99
```
1010

1111
## Options

example/resnet50/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Please got the [main page](https://github.com/intel/intel-extension-for-openxla/
1010

1111
### 2. Install jax and flax
1212
```bash
13-
pip install jax==0.4.13 jaxlib==0.4.13
13+
pip install jax==0.4.20 jaxlib==0.4.20
1414
```
1515
### 3. Install dependency
1616
```bash

example/stable_diffusion/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ please got the [main page](https://github.com/intel/intel-extension-for-openxla/
1010

1111
### 2. Install jax
1212
```bash
13-
pip install jax==0.4.13 jaxlib==0.4.13 flax==0.7.0
13+
pip install jax==0.4.20 jaxlib==0.4.20 flax==0.7.0
1414
```
1515
### 3. Install huggingface transformers
1616

1717
```bash
18-
pip install transformers==4.27.4 diffusers==0.16.1 datasets==2.12.0 msgpack==1.0.7
18+
pip install transformers==4.32 diffusers==0.16.1 datasets==2.12.0 msgpack==1.0.7
1919
```
2020
## Run
2121

@@ -24,6 +24,7 @@ pip install transformers==4.27.4 diffusers==0.16.1 datasets==2.12.0 msgpack==1.0
2424
| **ENV** | **Description** | **PVC Platform** | **ATSM/DG2 Platform** |
2525
| :---: | :---: | :---: |:---: |
2626
| ZE_AFFINITY_MASK | Run this model on single GPU tile |export ZE_AFFINITY_MASK=0.0 | export ZE_AFFINITY_MASK=0.0 |
27+
| XLA_FLAGS | Cutomerlize xla debug options | export XLA_FLAGS="--xla_gpu_force_conv_nhwc" | export XLA_FLAGS="--xla_gpu_force_conv_nhwc" |
2728

2829
### 2. Inference Command
2930

example/t5/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ For benchmarking, you could skip this step because our model script will downloa
3737
git clone https://github.com/google-research/t5x.git
3838
bash install_xpu.sh
3939
pip install --upgrade intel-extension-for-openxla
40-
pip install jax==0.4.13 jaxlib==0.4.13
40+
pip install jax==0.4.20 jaxlib==0.4.20
4141
```
4242
## Inference
4343

example/t5/install_xpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pip uninstall tensorflow -y
1313
pip install tensorflow==2.12.0
1414

1515
conda install libstdcxx-ng==12.2.0 -c conda-forge -y
16-
pip install jax==0.4.13 jaxlib==0.4.13
16+
pip install jax==0.4.20 jaxlib==0.4.20
1717

1818
pip uninstall mdit-py-plugins jupytext -y
1919
pip install t5

example/t5/patch/version_time_dlpath.patch

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ index 37238ba..19bec2f 100644
88

99
-_jax_version = '0.4.11'
1010
-_jaxlib_version = '0.4.11'
11-
+_jax_version = '0.4.13'
12-
+_jaxlib_version = '0.4.13'
11+
+_jax_version = '0.4.20'
12+
+_jaxlib_version = '0.4.20'
1313

1414
setuptools.setup(
1515
name='t5x',

0 commit comments

Comments
 (0)