Skip to content

Commit 9c32fe8

Browse files
Merge pull request jax-ml#25357 from jakevdp:core-deps
PiperOrigin-RevId: 704808153
2 parents e418e88 + 6541a62 commit 9c32fe8

20 files changed

+84
-40
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1212

1313
## jax 0.4.38
1414

15+
* Deprecations
16+
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
17+
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
18+
`Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
19+
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
20+
{mod}`jax.extend` for information on the compatibility guarantees of these
21+
semi-public extensions.
22+
1523
## jax 0.4.37 (Dec 9, 2024)
1624

1725
This is a patch release of jax 0.4.36. Only "jax" was released at this version.

docs/contributor_guide.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,3 @@ some of JAX's (extensible) internals.
2525

2626
autodidax
2727
jep/index
28-
jax_internal_api

docs/jax.extend.core.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
``jax.extend.core`` module
2+
==========================
3+
4+
.. automodule:: jax.extend.core
5+
6+
.. autosummary::
7+
:toctree: _autosummary
8+
9+
ClosedJaxpr
10+
Jaxpr
11+
JaxprEqn
12+
Literal
13+
Primitive
14+
Token
15+
Var
16+
array_types
17+
jaxpr_as_fun
18+
primitives

docs/jax.extend.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Modules
1111
.. toctree::
1212
:maxdepth: 1
1313

14+
jax.extend.core
1415
jax.extend.ffi
1516
jax.extend.linear_util
1617
jax.extend.mlir

docs/jax_internal_api.rst

Lines changed: 0 additions & 14 deletions
This file was deleted.

jax/_src/cudnn/fused_attention_stablehlo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import math
1919

2020
import jax
21-
from jax import core
2221
from jax import dtypes
22+
from jax._src import core
2323
from jax._src import dispatch
2424
from jax._src.custom_partitioning import custom_partitioning
2525
from jax._src.interpreters import batching

jax/_src/cudnn/fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import functools
1616
import jax
17-
from jax import core as jax_core
17+
from jax._src import core as jax_core
1818
from jax.interpreters import mlir
1919
from jax.interpreters.mlir import hlo
2020
from jax.interpreters.mlir import ir

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
from typing import Any
2222

2323
import jax
24-
from jax import core as jax_core
2524
from jax import dtypes
2625
from jax._src import config
27-
from jax._src import core as jax_src_core
26+
from jax._src import core as jax_core
2827
from jax._src import sharding_impls
2928
from jax._src import tpu_custom_call
3029
from jax._src.interpreters import mlir
@@ -189,7 +188,7 @@ def lower_module(for_verification: bool):
189188
# Replace in_avals to physical avals.
190189
# This step is required for mapping logical types to physical types.
191190
# (e.g. PRNG key -> uint32[2])
192-
physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
191+
physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in]
193192
ctx = ctx.replace(avals_in=physical_avals)
194193

195194
# Booleans are loaded into the kernel as integers.

jax/_src/pallas/mosaic_gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pytype_strict_library(
4444
deps = [
4545
":lowering",
4646
"//jax",
47+
"//jax:core",
4748
"//jax:mlir",
4849
"//jax:mosaic_gpu",
4950
"//jax/_src/pallas",

jax/_src/pallas/mosaic_gpu/pallas_call_registration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import warnings
2424

2525
import jax
26-
from jax import core as jax_core
26+
from jax._src import core as jax_core
2727
from jax._src.interpreters import mlir
2828
from jax._src.pallas import core as pallas_core
2929
from jax._src.pallas.mosaic_gpu import lowering

0 commit comments

Comments
 (0)