Skip to content

Commit 772339e

Browse files
Merge pull request #25508 from jakevdp:dep-core
PiperOrigin-RevId: 706810732
2 parents 8552852 + cfa3884 commit 772339e

File tree

1 file changed

+29
-36
lines changed

1 file changed

+29
-36
lines changed

jax/core.py

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,22 @@
2020
AbstractValue as AbstractValue,
2121
Atom as Atom,
2222
CallPrimitive as CallPrimitive,
23-
ClosedJaxpr as ClosedJaxpr,
2423
DShapedArray as DShapedArray,
2524
DropVar as DropVar,
2625
Effect as Effect,
2726
Effects as Effects,
2827
get_opaque_trace_state as get_opaque_trace_state,
2928
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
30-
Jaxpr as Jaxpr,
3129
JaxprDebugInfo as JaxprDebugInfo,
32-
JaxprEqn as JaxprEqn,
3330
JaxprPpContext as JaxprPpContext,
3431
JaxprPpSettings as JaxprPpSettings,
3532
JaxprTypeError as JaxprTypeError,
3633
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
37-
Literal as Literal,
3834
OutputType as OutputType,
3935
ParamDict as ParamDict,
40-
Primitive as Primitive,
4136
ShapedArray as ShapedArray,
42-
Token as Token,
4337
Trace as Trace,
4438
Tracer as Tracer,
45-
Var as Var,
4639
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
4740
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
4841
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
@@ -88,28 +81,28 @@
8881

8982
from jax._src import core as _src_core
9083
_deprecations = {
91-
# TODO(jakevdp): re-deprecate these after migrating some downstream uses.
92-
# "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, "
93-
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
94-
# _src_core.ClosedJaxpr),
95-
# "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, "
96-
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
97-
# _src_core.Jaxpr),
98-
# "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, "
99-
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
100-
# _src_core.JaxprEqn),
101-
# "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, "
102-
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
103-
# _src_core.Literal),
104-
# "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, "
105-
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
106-
# _src_core.Primitive),
107-
# "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, "
108-
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
109-
# _src_core.Token),
110-
# "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
111-
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
112-
# _src_core.Var),
84+
# Added 2024-12-16
85+
"ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, "
86+
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
87+
_src_core.ClosedJaxpr),
88+
"Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, "
89+
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
90+
_src_core.Jaxpr),
91+
"JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, "
92+
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
93+
_src_core.JaxprEqn),
94+
"Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, "
95+
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
96+
_src_core.Literal),
97+
"Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, "
98+
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
99+
_src_core.Primitive),
100+
"Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, "
101+
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
102+
_src_core.Token),
103+
"Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
104+
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
105+
_src_core.Var),
113106
# Added 2024-12-11
114107
"axis_frame": ("jax.core.axis_frame is deprecated.", _src_core.axis_frame),
115108
"AxisName": ("jax.core.AxisName is deprecated.", _src_core.AxisName),
@@ -196,21 +189,21 @@
196189
if typing.TYPE_CHECKING:
197190
AxisName = _src_core.AxisName
198191
AxisSize = _src_core.AxisSize
199-
# ClosedJaxpr = _src_core.ClosedJaxpr
192+
ClosedJaxpr = _src_core.ClosedJaxpr
200193
ConcretizationTypeError = _src_core.ConcretizationTypeError
201194
EvalTrace = _src_core.EvalTrace
202195
InDBIdx = _src_core.InDBIdx
203196
InputType = _src_core.InputType
204-
# Jaxpr = _src_core.Jaxpr
205-
# JaxprEqn = _src_core.JaxprEqn
206-
# Literal = _src_core.Literal
197+
Jaxpr = _src_core.Jaxpr
198+
JaxprEqn = _src_core.JaxprEqn
199+
Literal = _src_core.Literal
207200
MapPrimitive = _src_core.MapPrimitive
208201
OpaqueTraceState = _src_core.OpaqueTraceState
209202
OutDBIdx = _src_core.OutDBIdx
210-
# Primitive = _src_core.Primitive
211-
# Token = _src_core.Token
203+
Primitive = _src_core.Primitive
204+
Token = _src_core.Token
212205
TRACER_LEAK_DEBUGGER_WARNING = _src_core.TRACER_LEAK_DEBUGGER_WARNING
213-
# Var = _src_core.Var
206+
Var = _src_core.Var
214207
axis_frame = _src_core.axis_frame
215208
call_p = _src_core.call_p
216209
closed_call_p = _src_core.closed_call_p

0 commit comments

Comments
 (0)