Skip to content

Commit 5cda053

Browse files
authored
Merge pull request #185 from ROCm/ci-upstream-sync-12-12-2024
CI: 12/12/24 upstream sync
2 parents 6dc4dee + 02831ed commit 5cda053

File tree

103 files changed

+2463
-878
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

103 files changed

+2463
-878
lines changed

.bazelrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ build:avx_windows --copt=/arch:AVX
9696

9797
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
9898

99+
# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL).
100+
build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true
101+
build:mkl_aarch64_threadpool --@compute_library//:openmp=false
102+
build:mkl_aarch64_threadpool -c opt
103+
99104
# Disable clang extention that rejects type definitions within offsetof.
100105
# This was added in clang-16 by https://reviews.llvm.org/D133574.
101106
# Can be removed once upb is updated, since a type definition is used within

.github/workflows/bazel_cpu_rbe.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ on:
1111
options:
1212
- 'yes'
1313
- 'no'
14+
pull_request:
15+
branches:
16+
- main
1417

1518
concurrency:
1619
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -21,15 +24,18 @@ jobs:
2124
if: github.event.repository.fork == false
2225
strategy:
2326
matrix:
24-
runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"]
27+
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
28+
enable-x_64: [1, 0]
2529

2630
runs-on: ${{ matrix.runner }}
27-
# TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available
2831
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
29-
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }}
32+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
3033

3134
env:
3235
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
36+
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
37+
38+
name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
3339

3440
steps:
3541
- uses: actions/checkout@v3

.github/workflows/bazel_gpu_rbe.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ on:
1111
options:
1212
- 'yes'
1313
- 'no'
14+
pull_request:
15+
branches:
16+
- main
1417

1518
concurrency:
1619
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -22,12 +25,16 @@ jobs:
2225
strategy:
2326
matrix:
2427
runner: ["linux-x86-n2-16"]
28+
enable-x_64: [1, 0]
2529

2630
runs-on: ${{ matrix.runner }}
2731
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
2832

2933
env:
3034
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
35+
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
36+
37+
name: "Bazel single accelerator GPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
3138

3239
steps:
3340
- uses: actions/checkout@v3

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ repos:
3636
- id: mypy
3737
files: (jax/|tests/typing_test\.py)
3838
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
39-
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy~=2.1.0]
39+
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0]
4040
args: [--config=pyproject.toml]
4141

4242
- repo: https://github.com/mwouts/jupytext

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1919
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
2020
{mod}`jax.extend` for information on the compatibility guarantees of these
2121
semi-public extensions.
22+
* Several previously-deprecated APIs have been removed, including:
23+
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
24+
`non_negative_dim`.
25+
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
26+
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
2227

2328
## jax 0.4.37 (Dec 9, 2024)
2429

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ are instances of such transformations. Others are
4747
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
4848
parallel programming of multiple accelerators, with more to come.
4949

50-
This is a research project, not an official Google product. Expect bugs and
50+
This is a research project, not an official Google product. Expect
5151
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
5252
Please help by trying it out, [reporting
5353
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you

build/build.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,10 @@ async def main():
485485

486486
if not args.disable_mkl_dnn:
487487
logging.debug("Enabling MKL DNN")
488-
wheel_build_command.append("--config=mkl_open_source_only")
488+
if target_cpu == "aarch64":
489+
wheel_build_command.append("--config=mkl_aarch64_threadpool")
490+
else:
491+
wheel_build_command.append("--config=mkl_open_source_only")
489492

490493
if args.target_cpu_features == "release":
491494
if arch in ["x86_64", "AMD64"]:

docs/jax.lib.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ jax.lib.xla_bridge
1111
.. autosummary::
1212
:toctree: _autosummary
1313

14-
default_backend
1514
get_backend
1615
get_compile_options
1716

docs/jax.numpy.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ namespace; they are listed below.
274274
mask_indices
275275
matmul
276276
matrix_transpose
277+
matvec
277278
max
278279
maximum
279280
mean
@@ -428,6 +429,7 @@ namespace; they are listed below.
428429
var
429430
vdot
430431
vecdot
432+
vecmat
431433
vectorize
432434
vsplit
433435
vstack

examples/ffi/tests/cpu_examples_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def test_array_attr_jit_cache(self):
3737
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
3838
with jtu.count_jit_and_pmap_lowerings() as count:
3939
jit_array_attr(5)
40-
self.assertEqual(count[0], 1) # compiles once the first time
40+
self.assertEqual(count(), 1) # compiles once the first time
4141
with jtu.count_jit_and_pmap_lowerings() as count:
4242
jit_array_attr(5)
43-
self.assertEqual(count[0], 0) # cache hit
43+
self.assertEqual(count(), 0) # cache hit
4444

4545
def test_array_attr_no_jit(self):
4646
with jax.disable_jit():

0 commit comments

Comments
 (0)