Commit 8ed59d8

superbobry authored and Google-ML-Automation committed
Removed jax._src.raise_to_shaped
It is just an identity after the "stackless" rewrite.

PiperOrigin-RevId: 745042532

1 parent c2eaedf · commit 8ed59d8
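For downstream code that still calls raise_to_shaped, the change is a drop-in simplification, as the diffs below show: raise_to_shaped is now an identity, so wrapping get_aval with it does nothing. A minimal sketch of the migration, assuming the internal module jax._src.core is importable (internal APIs may move between releases):

    import jax.numpy as jnp
    from jax._src import core as jax_core

    x = jnp.ones((2, 3), dtype=jnp.float32)

    # Before this commit, call sites wrapped get_aval:
    #   aval = jax_core.raise_to_shaped(jax_core.get_aval(x))
    # raise_to_shaped became an identity after the "stackless" rewrite,
    # so the wrapper can simply be dropped:
    aval = jax_core.get_aval(x)
    print(aval.shape, aval.dtype)  # (2, 3) float32

The remaining files drop the wrapper at each call site (or replace it with an explicit identity, as in tests/api_test.py), and the definition itself is deleted from jax/_src/core.py.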

File tree: 5 files changed (+5, −18 lines)


jax/_src/core.py (0 additions, 5 deletions)

@@ -2229,11 +2229,6 @@ def block_until_ready(self):
 pytype_aval_mappings[Token] = lambda _: abstract_token


-# TODO(dougalm): Deprecate these. They're just here for backwards compat.
-def raise_to_shaped(aval):
-  return aval
-raise_to_shaped_mappings: dict[type, Callable] = {}
-
 ### Operations on shapes and dimension sizes.

 class InconclusiveDimensionOperation(Exception):

jax/_src/pallas/fuser/fusable.py (1 addition, 5 deletions)

@@ -29,10 +29,6 @@
 fusable_p.multiple_results = True


-def _get_aval(x):
-  return jax_core.raise_to_shaped(jax_core.get_aval(x))
-
-
 def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
   return fusion_lib.Fusion(
       func=lambda: x,

@@ -53,7 +49,7 @@ def wrapped(*args):
     flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
         lu.wrap_init(wrapped, debug_info=debug_info), in_tree
     )
-    flat_avals = [_get_aval(x) for x in flat_args]
+    flat_avals = [jax_core.get_aval(x) for x in flat_args]
     jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
     out_tree = out_tree_thunk()
     out = fusable_p.bind(

jax/_src/pallas/fuser/jaxpr_fusion.py (1 addition, 5 deletions)

@@ -28,10 +28,6 @@
 from jax._src.pallas.fuser.fusable import fusable_p


-def _get_aval(x):
-  return jax_core.raise_to_shaped(jax_core.get_aval(x))
-
-
 def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
   """Fuses a function into a single fusable.


@@ -52,7 +48,7 @@ def wrapper(*args, **kwargs):
     flat_fun, out_tree_thunk = api_util.flatten_fun(
         lu.wrap_init(f, debug_info=debug_info), in_tree
     )
-    flat_avals = [_get_aval(x) for x in flat_args]
+    flat_avals = [jax_core.get_aval(x) for x in flat_args]
     jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
     if debug:
       print("Jaxpr before fusion:")

tests/api_test.py (1 addition, 1 deletion)

@@ -5024,7 +5024,7 @@ def g(x):

     # Make sure that introducing constants in vmap works.
     constant_introducing_p = core.Primitive('introduce_constant')
-    constant_introducing_p.def_abstract_eval(core.raise_to_shaped)
+    constant_introducing_p.def_abstract_eval(lambda x: x)
     def _constant_introducing_batcher(xs, ds):
       (x,), (d,) = xs, ds
       return (x + np.arange(x.size, dtype=x.dtype).reshape(x.shape)), d

tests/state_test.py (2 additions, 2 deletions)

@@ -792,7 +792,7 @@ def body(i, st):
       lax.fori_loop(0, 5, body, init_val=())
       return a_ref[...], b_ref[...]

-    ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x)))
+    ref = lambda x: AbstractRef(core.get_aval(x))
     f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.))
     jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True])
     # Effects on y_ref were discharged away but not the effects on x_ref

@@ -1139,7 +1139,7 @@ def false_fun():
         y_ref[...] = 2.
       lax.cond(pred, true_fun, false_fun)
       return x_ref[...], y_ref[...]
-    ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x)))
+    ref = lambda x: AbstractRef(core.get_aval(x))
     f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.))
     jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True])
     # Effects on y_ref were discharged away but not the effects on x_ref
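The ref helper in the state tests above follows the same pattern: the aval returned by get_aval is wrapped directly in AbstractRef, with no raise_to_shaped step in between. A standalone sketch of that helper, assuming AbstractRef can be imported from jax._src.state.types (an internal path that may change between releases):

    from jax._src import core
    from jax._src.state.types import AbstractRef

    # get_aval already yields an abstract (shaped) value, so it can be
    # handed straight to the ref type used by the discharge tests.
    ref = lambda x: AbstractRef(core.get_aval(x))
    r = ref(1.)
    print(type(r).__name__, r.inner_aval)  # e.g. AbstractRef float32[]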
