Skip to content

Commit 9afbd23

Browse files
authored
Merge pull request #131 from ROCm/ci-upstream-sync-11_1
CI: 11/06/24 upstream sync
2 parents 350e04d + 587d733 commit 9afbd23

File tree

168 files changed

+3495
-2227
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

168 files changed

+3495
-2227
lines changed

.github/workflows/ci-build.yaml

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ jobs:
3434
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
3535
with:
3636
python-version: 3.11
37-
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
37+
- run: python -m pip install pre-commit
38+
- uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
39+
with:
40+
path: ~/.cache/pre-commit
41+
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
42+
- run: pre-commit run --show-diff-on-failure --color=always
3843

3944
build:
4045
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
@@ -211,14 +216,14 @@ jobs:
211216
212217
ffi:
213218
name: FFI example
214-
runs-on: ubuntu-latest
215-
timeout-minutes: 5
219+
runs-on: ROCM-Ubuntu
220+
timeout-minutes: 30
216221
steps:
217222
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
218-
- name: Set up Python 3.11
223+
- name: Set up Python
219224
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
220225
with:
221-
python-version: 3.11
226+
python-version: 3.12
222227
- name: Get pip cache dir
223228
id: pip-cache
224229
run: |
@@ -239,6 +244,10 @@ jobs:
239244
# a different toolchain. GCC is the default compiler on the
240245
# 'ubuntu-latest' runner, but we still set this explicitly just to be
241246
# clear.
242-
CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
243-
- name: Run tests
247+
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
248+
- name: Run CPU tests
249+
run: python -m pytest examples/ffi/tests
250+
env:
251+
JAX_PLATFORM_NAME: cpu
252+
- name: Run GPU tests
244253
run: python -m pytest examples/ffi/tests

.github/workflows/jax-array-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
31-
ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09
31+
ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,26 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2323
* The deprecated module `jax.experimental.export` has been removed. It was replaced
2424
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
2525
for information on migrating to the new API.
26+
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
27+
has been removed, after being deprecated in v0.4.27.
28+
* The following deprecated methods and functions in {mod}`jax.export` have
29+
been removed:
30+
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
31+
already.
32+
* `jax.export.Exported.lowering_platforms`: use `platforms`.
33+
* `jax.export.Exported.mlir_module_serialization_version`:
34+
use `calling_convention_version`.
35+
* `jax.export.Exported.uses_shape_polymorphism`:
36+
use `uses_global_constants`.
37+
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
38+
`platforms` instead.
39+
* Hashing of tracers, which has been deprecated since version 0.4.30, now
40+
results in a `TypeError`.
41+
42+
* New Features
43+
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
44+
passing compilation options to XLA. For the moment it's undocumented and
45+
may be in flux.
2646

2747
## jax 0.4.35 (Oct 22, 2024)
2848

README.md

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -411,23 +411,18 @@ community-supported conda build, and answers to some frequently-asked questions.
411411

412412
## Neural network libraries
413413

414-
Multiple Google research groups develop and share libraries for training neural
415-
networks in JAX. If you want a fully featured library for neural network
414+
Multiple Google research groups at Google DeepMind and Alphabet develop and share libraries
415+
for training neural networks in JAX. If you want a fully featured library for neural network
416416
training with examples and how-to guides, try
417-
[Flax](https://github.com/google/flax). Check out the new [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) API for a
418-
simplified development experience.
419-
420-
Google X maintains the neural network library
421-
[Equinox](https://github.com/patrick-kidger/equinox). This is used as the
422-
foundation for several other libraries in the JAX ecosystem.
423-
424-
In addition, DeepMind has open-sourced an [ecosystem of libraries around
425-
JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research)
426-
including [Optax](https://github.com/deepmind/optax) for gradient processing and
427-
optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and
428-
[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch
429-
the NeurIPS 2020 JAX Ecosystem at DeepMind talk
430-
[here](https://www.youtube.com/watch?v=iDxJxIyzSiM))
417+
[Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html).
418+
419+
Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem)
420+
on the JAX documentation site for a list of JAX-based network libraries, which includes
421+
[Optax](https://github.com/deepmind/optax) for gradient processing and
422+
optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and
423+
[Equinox](https://github.com/patrick-kidger/equinox) for neural networks.
424+
(Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk
425+
[here](https://www.youtube.com/watch?v=iDxJxIyzSiM) for additional details.)
431426

432427
## Citing JAX
433428

build/requirements.in

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@
33
#
44
-r test-requirements.txt
55

6-
# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement
7-
# below.
8-
matplotlib~=3.8.4; python_version<="3.10"
9-
matplotlib; python_version>="3.11"
10-
116
#
127
# build deps
138
#

0 commit comments

Comments
 (0)