Skip to content

Commit 244338e

Browse files
author
Ashiq Imran
committed
Merge branch 'qiming/main-v0.4.30' into r0.5
2 parents b0c14e3 + cbc2519 commit 244338e

33 files changed

+1523
-1688
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Verified Hardware Platforms:
3737
* Ubuntu 22.04, SUSE Linux Enterprise Server(SLES) 15 SP4
3838
* Intel® Data Center GPU Max Series
3939
* [Intel® oneAPI Base Toolkit 2024.2](https://www.intel.com/content/www/us/en/developer/articles/release-notes/intel-oneapi-toolkit-release-notes.html)
40-
* Jax/Jaxlib 0.4.26
40+
* Jax/Jaxlib 0.4.30
4141
* Python 3.9-3.12
4242
* pip 19.0 or later (requires manylinux2014 support)
4343

WORKSPACE

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ http_archive(
1414
name = "xla",
1515
patch_args = ["-p1"],
1616
patches = ["//third_party:openxla.patch"],
17-
sha256 = "fa6e7d17acc362b56c57c43224e6e3eca8569adae864e2fa191cc9d13edf4309",
18-
strip_prefix = "xla-4e8e23f16bc925b6f27817de098a8e1e81296bb5",
17+
sha256 = "083c7281a629647ab2cc32f054afec74893c33e75328783b8085c818f48235ff",
18+
strip_prefix = "xla-79fd5733f99b3c0948d7202bc1bbe1ee3980da5c",
1919
urls = [
20-
"https://github.com/openxla/xla/archive/4e8e23f16bc925b6f27817de098a8e1e81296bb5.tar.gz",
20+
"https://github.com/openxla/xla/archive/79fd5733f99b3c0948d7202bc1bbe1ee3980da5c.tar.gz",
2121
],
2222
)
2323

@@ -33,6 +33,35 @@ http_archive(
3333
# path = "/path/to/xla",
3434
# )
3535

36+
# Initialize hermetic Python
37+
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
38+
39+
python_init_rules()
40+
41+
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
42+
43+
python_init_repositories(
44+
default_python_version = "system",
45+
requirements = {
46+
"3.9": "@xla//:requirements_lock_3_9.txt",
47+
"3.10": "@xla//:requirements_lock_3_10.txt",
48+
"3.11": "@xla//:requirements_lock_3_11.txt",
49+
"3.12": "@xla//:requirements_lock_3_12.txt",
50+
},
51+
)
52+
53+
load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
54+
55+
python_init_toolchains()
56+
57+
load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip")
58+
59+
python_init_pip()
60+
61+
load("@pypi//:requirements.bzl", "install_deps")
62+
63+
install_deps()
64+
3665
load("@xla//:workspace4.bzl", "xla_workspace4")
3766

3867
xla_workspace4()

example/sdxl/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ please got the [main page](https://github.com/intel/intel-extension-for-openxla/
1919
Mark `intel-extension-for-openxla` folder as \<WORKSPACE\>, then
2020
```bash
2121
cd <WORKSPACE>/example/sdxl/
22-
pip install transformers==4.38 diffusers==0.26.3 datasets==2.20.0 msgpack==1.0.7
22+
pip install transformers==4.47 diffusers==0.31.0 datasets==2.20.0 msgpack==1.1.0
2323
pip install -r ../../test/requirements.txt
2424
```
2525

example/t5/install_xpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ git apply ../patch/t5.patch
88
ln -s /usr/local/bin/pip /usr/bin/pip
99
pip uninstall tensorflow-metadata numba cudf -y
1010
pip uninstall tensorflow -y
11-
pip install tensorflow==2.12.0
11+
pip install tensorflow==2.18.0
1212

1313
conda install libstdcxx-ng==12.2.0 -c conda-forge -y
1414

example/t5/patch/t5.patch

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
diff --git a/setup.py b/setup.py
2-
index 37238ba..5ee7b8a 100644
2+
index 37238ba..6a97d34 100644
33
--- a/setup.py
44
+++ b/setup.py
55
@@ -27,8 +27,8 @@ from version import __version__ # pylint: disable=g-import-not-at-top
@@ -8,8 +8,8 @@ index 37238ba..5ee7b8a 100644
88

99
-_jax_version = '0.4.11'
1010
-_jaxlib_version = '0.4.11'
11-
+_jax_version = '0.4.26'
12-
+_jaxlib_version = '0.4.26'
11+
+_jax_version = '0.4.30'
12+
+_jaxlib_version = '0.4.30'
1313

1414
setuptools.setup(
1515
name='t5x',
@@ -19,11 +19,32 @@ index 37238ba..5ee7b8a 100644
1919
'cached_property',
2020
- 'clu @ git+https://github.com/google/CommonLoopUtils#egg=clu',
2121
- 'flax @ git+https://github.com/google/flax#egg=flax',
22-
+ 'clu == 0.0.9',
23-
+ 'flax >= 0.8.2',
22+
+ 'clu == 0.0.12',
23+
+ 'flax >= 0.8.5',
2424
'fiddle >= 0.2.5',
2525
'gin-config',
2626
f'jax >= {_jax_version}',
27+
@@ -61,7 +61,7 @@ setuptools.setup(
28+
'numpy',
29+
'optax @ git+https://github.com/deepmind/optax#egg=optax',
30+
'orbax-checkpoint',
31+
- 'seqio @ git+https://github.com/google/seqio#egg=seqio',
32+
+ 'seqio >= 0.0.18',
33+
'tensorflow-cpu',
34+
'tensorstore >= 0.1.20',
35+
# remove this when sentencepiece_model_pb2 is re-generated in the
36+
diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py
37+
index c8af7d0..4945b2c 100644
38+
--- a/t5x/checkpoints.py
39+
+++ b/t5x/checkpoints.py
40+
@@ -45,7 +45,6 @@ from flax import serialization
41+
from flax import traverse_util
42+
import jax
43+
from jax import monitoring
44+
-import jax.config
45+
from jax.experimental import multihost_utils
46+
from jax.experimental.array_serialization import serialization as array_serialization
47+
import jax.numpy as jnp
2748
diff --git a/t5x/config_utils.py b/t5x/config_utils.py
2849
index abd3f8f..e6e1bd9 100644
2950
--- a/t5x/config_utils.py

example/t5/quick_start.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ echo $MODEL_PATH
6565
echo "Please make sure ${NUM_GPUS} is the number of visible CUDA devices you have"
6666

6767
# Setting XLA flags
68-
export XLA_FLAGS="--xla_gpu_simplify_all_fp_conversions --xla_gpu_all_reduce_combine_threshold_bytes=136314880 ${XLA_FLAGS}"
68+
export XLA_FLAGS="--xla_allow_excess_precision --xla_gpu_all_reduce_combine_threshold_bytes=136314880 ${XLA_FLAGS}"
6969

7070

7171
PREFIX=""

test/BRANCH_NAME

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
jax-v0.4.26
1+
jax-v0.4.30

test/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
jax==0.4.26
2-
jaxlib==0.4.26
3-
flax==0.8.2
1+
jax==0.4.30
2+
jaxlib==0.4.30
3+
flax==0.8.5

third_party/onednn/onednn_gpu.BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ _CMAKE_COMMON_LIST = {
3535
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
3636
"#cmakedefine DNNL_EXPERIMENTAL": "#define DNNL_EXPERIMENTAL",
3737
"#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
38+
"#cmakedefine DNNL_SYCL_GENERIC": "#define DNNL_SYCL_GENERIC",
39+
"#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_INTEL",
40+
"#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS",
41+
"#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING",
42+
"#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 1",
3843
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
3944
"#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
4045
"#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
@@ -103,6 +108,12 @@ gen_onednn_version(
103108
header_out = "include/oneapi/dnnl/dnnl_version.h",
104109
)
105110

111+
gen_onednn_version(
112+
name = "dnnl_version_hash_h",
113+
header_in = "include/oneapi/dnnl/dnnl_version_hash.h.in",
114+
header_out = "include/oneapi/dnnl/dnnl_version_hash.h",
115+
)
116+
106117
filegroup(
107118
name = "onednn_src",
108119
srcs = glob(
@@ -122,6 +133,7 @@ filegroup(
122133
],
123134
) + [
124135
":dnnl_config_h",
136+
":dnnl_version_hash_h",
125137
":header_generator",
126138
":kernel_list_generator",
127139
":onednn_version_generator",
@@ -156,6 +168,9 @@ cc_library(
156168
"include/oneapi/dnnl",
157169
"src",
158170
"src/common",
171+
"src/gpu/intel/jit/gemm/",
172+
"src/gpu/intel/jit/gemm/include/",
173+
"src/gpu/intel/jit/ngen/",
159174
"src/intel/ocl",
160175
"src/sycl",
161176
],

0 commit comments

Comments
 (0)