Skip to content

Commit 502b581

Browse files
committed
jax.typeof: improve documentation & type annotation
1 parent aced11d commit 502b581

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

docs/jax.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ Miscellaneous
262262
print_environment_info
263263
live_arrays
264264
clear_caches
265+
typeof
265266

266267
Checkpoint policies
267268
-------------------

jax/_src/core.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1808,7 +1808,14 @@ def get_aval(x: Any) -> Any:
18081808
)
18091809
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
18101810

1811-
typeof = get_aval
1811+
1812+
# TODO(phawkins): the return type should be AbstractValue.
1813+
def typeof(x: Any, /) -> Any:
1814+
"""Return the JAX type (i.e. :class:`AbstractValue`) of the input.
1815+
1816+
Raises a ``TypeError`` if ``x`` is not a valid JAX type.
1817+
"""
1818+
return get_aval(x)
18121819

18131820
def is_concrete(x):
18141821
return to_concrete_value(x) is not None

tests/documentation_coverage_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def jax_docs_dir() -> str:
5151

5252

5353
UNDOCUMENTED_APIS = {
54-
'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'typeof', 'version'],
54+
'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'version'],
5555
'jax.custom_batching': ['custom_vmap', 'sequential_vmap'],
5656
'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'],
5757
'jax.custom_transpose': ['custom_transpose'],

0 commit comments

Comments
 (0)