3636
3737export = set_module ('jax.nn.initializers' )
3838
39- KeyArray = Array
4039# TODO: Import or define these to match
4140# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
4241DTypeLikeFloat = Any
4847@typing .runtime_checkable
4948class Initializer (Protocol ):
5049 @staticmethod
51- def __call__ (key : KeyArray ,
50+ def __call__ (key : Array ,
5251 shape : core .Shape ,
5352 dtype : DTypeLikeInexact = jnp .float_ ) -> Array :
5453 raise NotImplementedError
5554
5655@export
57- def zeros (key : KeyArray ,
56+ def zeros (key : Array ,
5857 shape : core .Shape ,
5958 dtype : DTypeLikeInexact = jnp .float_ ) -> Array :
6059 """An initializer that returns a constant array full of zeros.
@@ -69,7 +68,7 @@ def zeros(key: KeyArray,
6968 return jnp .zeros (shape , dtypes .canonicalize_dtype (dtype ))
7069
7170@export
72- def ones (key : KeyArray ,
71+ def ones (key : Array ,
7372 shape : core .Shape ,
7473 dtype : DTypeLikeInexact = jnp .float_ ) -> Array :
7574 """An initializer that returns a constant array full of ones.
@@ -100,7 +99,7 @@ def constant(value: ArrayLike,
10099 Array([[-7., -7., -7.],
101100 [-7., -7., -7.]], dtype=float32)
102101 """
103- def init (key : KeyArray ,
102+ def init (key : Array ,
104103 shape : core .Shape ,
105104 dtype : DTypeLikeInexact = dtype ) -> Array :
106105 dtype = dtypes .canonicalize_dtype (dtype )
@@ -126,7 +125,7 @@ def uniform(scale: RealNumeric = 1e-2,
126125 Array([[7.298188 , 8.691938 , 8.7230015],
127126 [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
128127 """
129- def init (key : KeyArray ,
128+ def init (key : Array ,
130129 shape : core .Shape ,
131130 dtype : DTypeLikeInexact = dtype ) -> Array :
132131 dtype = dtypes .canonicalize_dtype (dtype )
@@ -152,7 +151,7 @@ def normal(stddev: RealNumeric = 1e-2,
152151 Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
153152 [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
154153 """
155- def init (key : KeyArray ,
154+ def init (key : Array ,
156155 shape : core .Shape ,
157156 dtype : DTypeLikeInexact = dtype ) -> Array :
158157 dtype = dtypes .canonicalize_dtype (dtype )
@@ -189,7 +188,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
189188 [-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
190189 """
191190
192- def init (key : KeyArray ,
191+ def init (key : Array ,
193192 shape : core .Shape ,
194193 dtype : DTypeLikeInexact = dtype ) -> Array :
195194 dtype = dtypes .canonicalize_dtype (dtype )
@@ -230,7 +229,7 @@ def _compute_fans(shape: Sequence[int],
230229 fan_out = out_size * receptive_field_size
231230 return fan_in , fan_out
232231
233- def _complex_uniform (key : KeyArray ,
232+ def _complex_uniform (key : Array ,
234233 shape : Sequence [int ],
235234 dtype : DTypeLikeInexact ) -> Array :
236235 """
@@ -244,7 +243,7 @@ def _complex_uniform(key: KeyArray,
244243 theta = 2 * jnp .pi * random .uniform (key_theta , shape , real_dtype ).astype (dtype )
245244 return r * jnp .exp (1j * theta )
246245
247- def _complex_truncated_normal (key : KeyArray , upper : ArrayLike ,
246+ def _complex_truncated_normal (key : Array , upper : ArrayLike ,
248247 shape : Sequence [int ],
249248 dtype : DTypeLikeInexact ) -> Array :
250249 """
@@ -314,7 +313,7 @@ def variance_scaling(
314313 dtype: the dtype of the weights.
315314 """
316315
317- def init (key : KeyArray ,
316+ def init (key : Array ,
318317 shape : core .Shape ,
319318 dtype : DTypeLikeInexact = dtype ) -> Array :
320319 shape = core .canonicalize_shape (shape )
@@ -599,7 +598,7 @@ def orthogonal(scale: RealNumeric = 1.0,
599598 Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
600599 [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
601600 """
602- def init (key : KeyArray ,
601+ def init (key : Array ,
603602 shape : core .Shape ,
604603 dtype : DTypeLikeInexact = dtype ) -> Array :
605604 dtype = dtypes .canonicalize_dtype (dtype )
@@ -654,7 +653,7 @@ def delta_orthogonal(
654653
655654 .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
656655 """
657- def init (key : KeyArray ,
656+ def init (key : Array ,
658657 shape : core .Shape ,
659658 dtype : DTypeLikeInexact = dtype ) -> Array :
660659 dtype = dtypes .canonicalize_dtype (dtype )
0 commit comments