Skip to content

Commit 3a26804

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
1 parent c6b164d commit 3a26804

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
Device = _xc.Device
8080
del _xc
8181

82-
from jax._src.core import get_ty as get_ty
82+
from jax._src.core import typeof as typeof
8383
from jax._src.api import effects_barrier as effects_barrier
8484
from jax._src.api import block_until_ready as block_until_ready
8585
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401

jax/_src/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1576,7 +1576,7 @@ def get_aval(x):
15761576
return get_aval(x.__jax_array__())
15771577
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
15781578

1579-
get_ty = get_aval
1579+
typeof = get_aval
15801580

15811581
def is_concrete(x):
15821582
return to_concrete_value(x) is not None

tests/mutable_array_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def test_explicit_sharding_after_indexing(self):
216216

217217
@jax.jit
218218
def f(x_ref):
219-
self.assertEqual(core.get_ty(x_ref).sharding.spec,
220-
core.get_ty(x_ref[...]).sharding.spec)
219+
self.assertEqual(core.typeof(x_ref).sharding.spec,
220+
core.typeof(x_ref[...]).sharding.spec)
221221
y = x_ref[...] + 1
222222
return y
223223

tests/pjit_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4883,11 +4883,11 @@ def test_basic_mul(self, mesh):
48834883
arr = jax.device_put(np_inp, s)
48844884

48854885
def f(x):
4886-
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
4886+
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
48874887
x = x * 2
4888-
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
4888+
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
48894889
x = x * x
4890-
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
4890+
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
48914891
return x
48924892

48934893
# Eager mode

0 commit comments

Comments
 (0)