Skip to content

Commit 8f4e13f

Browse files
Merge pull request #25550 from hawkinsp:postrelease
PiperOrigin-RevId: 707276891
2 parents 7e96914 + ff52aed commit 8f4e13f

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
1010
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
1111
-->
1212

13-
## jax 0.4.38
13+
## Unreleased
14+
15+
## jax 0.4.38 (Dec 17, 2024)
1416

1517
* Changes:
1618
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added

jax/_src/tree_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def flatten_one_level_with_keys(
613613
tree: Any,
614614
) -> tuple[Iterable[KeyLeafPair], Hashable]:
615615
"""Flatten the given pytree node by one level, with keys."""
616-
out = default_registry.flatten_one_level_with_keys(tree)
616+
out = default_registry.flatten_one_level_with_keys(tree) # type: ignore
617617
if out is None:
618618
raise ValueError(f"can't tree-flatten type: {type(tree)}")
619619
else:

jax/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pathlib
2222
import subprocess
2323

24-
_version = "0.4.38"
24+
_version = "0.4.39"
2525
# The following line is overwritten by build scripts in distributions &
2626
# releases. Do not modify this manually, or jax/jaxlib build will fail.
2727
_release_version: str | None = None
@@ -137,7 +137,7 @@ def make_release_tree(self, base_dir, files):
137137

138138

139139
__version__ = _get_version_string()
140-
_minimum_jaxlib_version = "0.4.36"
140+
_minimum_jaxlib_version = "0.4.38"
141141

142142
def _version_as_tuple(version_str):
143143
return tuple(int(i) for i in version_str.split(".") if i.isdigit())

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919

2020
project_name = 'jax'
2121

22-
_current_jaxlib_version = '0.4.36'
22+
_current_jaxlib_version = '0.4.38'
2323
# The following should be updated after each new jaxlib release.
24-
_latest_jaxlib_version_on_pypi = '0.4.36'
24+
_latest_jaxlib_version_on_pypi = '0.4.38'
2525

26-
_libtpu_version = '0.0.6'
26+
_libtpu_version = '0.0.7'
2727
_libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup'
2828

2929
def load_version_module(pkg_path):

0 commit comments

Comments
 (0)