Skip to content

Commit ce76d05

Browse files
Merge pull request #33372 from jakevdp:jax-nn-docs
PiperOrigin-RevId: 833482568
2 parents 10fe299 + 11f796c commit ce76d05

File tree

4 files changed

+7
-2
lines changed

4 files changed

+7
-2
lines changed

docs/jax.nn.initializers.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
2626
glorot_uniform
2727
he_normal
2828
he_uniform
29+
kaiming_normal
30+
kaiming_uniform
2931
lecun_normal
3032
lecun_uniform
3133
normal
@@ -34,4 +36,7 @@ data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
3436
truncated_normal
3537
uniform
3638
variance_scaling
39+
xavier_normal
40+
xavier_uniform
3741
zeros
42+
Initializer

docs/jax.nn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Activation functions
3333
hard_silu
3434
hard_swish
3535
hard_tanh
36+
tanh
3637
elu
3738
celu
3839
selu

jax/_src/nn/initializers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
@export
5050
@typing.runtime_checkable
5151
class Initializer(Protocol):
52+
"""Protocol for initializers returned by :mod:`jax.nn.initializers` APIs."""
5253
def __call__(self,
5354
key: Array,
5455
shape: core.Shape,

tests/documentation_coverage_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ def jax_docs_dir() -> str:
6565
'jax.lax.linalg': [api for api in dir(jax.lax.linalg) if api.endswith('_p')],
6666
'jax.memory': ['Space'],
6767
'jax.monitoring': ['clear_event_listeners', 'record_event', 'record_event_duration_secs', 'record_event_time_span', 'record_scalar', 'register_event_duration_secs_listener', 'register_event_listener', 'register_event_time_span_listener', 'register_scalar_listener', 'unregister_event_duration_listener', 'unregister_event_listener', 'unregister_event_time_span_listener', 'unregister_scalar_listener'],
68-
'jax.nn': ['tanh'],
69-
'jax.nn.initializers': ['Initializer', 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform'],
7068
'jax.numpy': ['bfloat16', 'bool', 'e', 'euler_gamma', 'float4_e2m1fn', 'float8_e3m4', 'float8_e4m3', 'float8_e4m3b11fnuz', 'float8_e4m3fn', 'float8_e4m3fnuz', 'float8_e5m2', 'float8_e5m2fnuz', 'float8_e8m0fnu', 'inf', 'int2', 'int4', 'nan', 'newaxis', 'pi', 'uint2', 'uint4'],
7169
'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'],
7270
'jax.random': ['key_impl', 'random_gamma_p'],

0 commit comments

Comments
 (0)