Skip to content

Commit ba626fa

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Bump JAX version after release.
PiperOrigin-RevId: 703472753
1 parent 9fc077a commit ba626fa

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
1010
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
1111
-->
1212

13-
## jax 0.4.36
13+
## jax 0.4.37
14+
15+
## jax 0.4.36 (Dec 5, 2024)
1416

1517
* Breaking Changes
1618
* This release lands "stackless", an internal change to JAX's tracing

jax/_src/tree_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,13 @@ def register_pytree_node(
291291
"""
292292
if xla_extension_version >= 299:
293293
default_registry.register_node( # type: ignore[call-arg]
294-
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
294+
nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type]
295295
)
296296
none_leaf_registry.register_node( # type: ignore[call-arg]
297-
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
297+
nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type]
298298
)
299299
dispatch_registry.register_node( # type: ignore[call-arg]
300-
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
300+
nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type]
301301
)
302302
else:
303303
default_registry.register_node(nodetype, flatten_func, unflatten_func)

jax/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pathlib
2222
import subprocess
2323

24-
_version = "0.4.36"
24+
_version = "0.4.37"
2525
# The following line is overwritten by build scripts in distributions &
2626
# releases. Do not modify this manually, or jax/jaxlib build will fail.
2727
_release_version: str | None = None

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
_current_jaxlib_version = '0.4.36'
2323
# The following should be updated after each new jaxlib release.
24-
_latest_jaxlib_version_on_pypi = '0.4.35'
24+
_latest_jaxlib_version_on_pypi = '0.4.36'
2525

2626
_libtpu_version = '0.0.5'
2727
_libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup'

0 commit comments

Comments
 (0)