Skip to content

Commit fee272e

Browse files
committed
Remove internal KeyArray alias
This was useful during the transition to typed PRNG keys, but is no longer necessary. It also makes generated HTML docs confusing: it's better to just use Array as we expect users to.
1 parent 8d84f28 commit fee272e

File tree

3 files changed

+65
-68
lines changed

3 files changed

+65
-68
lines changed

jax/_src/blocked_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Shape = random.Shape
2424

2525
class SampleFn(Protocol):
26-
def __call__(self, key: random.KeyArrayLike, *args, shape: Shape,
26+
def __call__(self, key: ArrayLike, *args, shape: Shape,
2727
**kwargs) -> Array:
2828
...
2929

@@ -43,7 +43,7 @@ def _compute_scalar_index(iteration_index: Sequence[int],
4343

4444

4545
def blocked_fold_in(
46-
global_key: random.KeyArrayLike,
46+
global_key: ArrayLike,
4747
total_size: Shape,
4848
block_size: Shape,
4949
tile_size: Shape,

jax/_src/nn/initializers.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
export = set_module('jax.nn.initializers')
3838

39-
KeyArray = Array
4039
# TODO: Import or define these to match
4140
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
4241
DTypeLikeFloat = Any
@@ -48,13 +47,13 @@
4847
@typing.runtime_checkable
4948
class Initializer(Protocol):
5049
@staticmethod
51-
def __call__(key: KeyArray,
50+
def __call__(key: Array,
5251
shape: core.Shape,
5352
dtype: DTypeLikeInexact = jnp.float_) -> Array:
5453
raise NotImplementedError
5554

5655
@export
57-
def zeros(key: KeyArray,
56+
def zeros(key: Array,
5857
shape: core.Shape,
5958
dtype: DTypeLikeInexact = jnp.float_) -> Array:
6059
"""An initializer that returns a constant array full of zeros.
@@ -69,7 +68,7 @@ def zeros(key: KeyArray,
6968
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
7069

7170
@export
72-
def ones(key: KeyArray,
71+
def ones(key: Array,
7372
shape: core.Shape,
7473
dtype: DTypeLikeInexact = jnp.float_) -> Array:
7574
"""An initializer that returns a constant array full of ones.
@@ -100,7 +99,7 @@ def constant(value: ArrayLike,
10099
Array([[-7., -7., -7.],
101100
[-7., -7., -7.]], dtype=float32)
102101
"""
103-
def init(key: KeyArray,
102+
def init(key: Array,
104103
shape: core.Shape,
105104
dtype: DTypeLikeInexact = dtype) -> Array:
106105
dtype = dtypes.canonicalize_dtype(dtype)
@@ -126,7 +125,7 @@ def uniform(scale: RealNumeric = 1e-2,
126125
Array([[7.298188 , 8.691938 , 8.7230015],
127126
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
128127
"""
129-
def init(key: KeyArray,
128+
def init(key: Array,
130129
shape: core.Shape,
131130
dtype: DTypeLikeInexact = dtype) -> Array:
132131
dtype = dtypes.canonicalize_dtype(dtype)
@@ -152,7 +151,7 @@ def normal(stddev: RealNumeric = 1e-2,
152151
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
153152
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
154153
"""
155-
def init(key: KeyArray,
154+
def init(key: Array,
156155
shape: core.Shape,
157156
dtype: DTypeLikeInexact = dtype) -> Array:
158157
dtype = dtypes.canonicalize_dtype(dtype)
@@ -189,7 +188,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
189188
[-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
190189
"""
191190

192-
def init(key: KeyArray,
191+
def init(key: Array,
193192
shape: core.Shape,
194193
dtype: DTypeLikeInexact = dtype) -> Array:
195194
dtype = dtypes.canonicalize_dtype(dtype)
@@ -230,7 +229,7 @@ def _compute_fans(shape: Sequence[int],
230229
fan_out = out_size * receptive_field_size
231230
return fan_in, fan_out
232231

233-
def _complex_uniform(key: KeyArray,
232+
def _complex_uniform(key: Array,
234233
shape: Sequence[int],
235234
dtype: DTypeLikeInexact) -> Array:
236235
"""
@@ -244,7 +243,7 @@ def _complex_uniform(key: KeyArray,
244243
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
245244
return r * jnp.exp(1j * theta)
246245

247-
def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
246+
def _complex_truncated_normal(key: Array, upper: ArrayLike,
248247
shape: Sequence[int],
249248
dtype: DTypeLikeInexact) -> Array:
250249
"""
@@ -314,7 +313,7 @@ def variance_scaling(
314313
dtype: the dtype of the weights.
315314
"""
316315

317-
def init(key: KeyArray,
316+
def init(key: Array,
318317
shape: core.Shape,
319318
dtype: DTypeLikeInexact = dtype) -> Array:
320319
shape = core.canonicalize_shape(shape)
@@ -599,7 +598,7 @@ def orthogonal(scale: RealNumeric = 1.0,
599598
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
600599
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
601600
"""
602-
def init(key: KeyArray,
601+
def init(key: Array,
603602
shape: core.Shape,
604603
dtype: DTypeLikeInexact = dtype) -> Array:
605604
dtype = dtypes.canonicalize_dtype(dtype)
@@ -654,7 +653,7 @@ def delta_orthogonal(
654653
655654
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
656655
"""
657-
def init(key: KeyArray,
656+
def init(key: Array,
658657
shape: core.Shape,
659658
dtype: DTypeLikeInexact = dtype) -> Array:
660659
dtype = dtypes.canonicalize_dtype(dtype)

0 commit comments

Comments
 (0)