Skip to content

Commit c1ae13b

Browse files
Merge pull request jax-ml#25009 from jakevdp:keyarray
PiperOrigin-RevId: 698865655
2 parents 1efef6b + fee272e commit c1ae13b

File tree

3 files changed

+65
-68
lines changed

3 files changed

+65
-68
lines changed

jax/_src/blocked_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Shape = random.Shape
2424

2525
class SampleFn(Protocol):
26-
def __call__(self, key: random.KeyArrayLike, *args, shape: Shape,
26+
def __call__(self, key: ArrayLike, *args, shape: Shape,
2727
**kwargs) -> Array:
2828
...
2929

@@ -43,7 +43,7 @@ def _compute_scalar_index(iteration_index: Sequence[int],
4343

4444

4545
def blocked_fold_in(
46-
global_key: random.KeyArrayLike,
46+
global_key: ArrayLike,
4747
total_size: Shape,
4848
block_size: Shape,
4949
tile_size: Shape,

jax/_src/nn/initializers.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
export = set_module('jax.nn.initializers')
3838

39-
KeyArray = Array
4039
# TODO: Import or define these to match
4140
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
4241
DTypeLikeFloat = Any
@@ -48,13 +47,13 @@
4847
@typing.runtime_checkable
4948
class Initializer(Protocol):
5049
@staticmethod
51-
def __call__(key: KeyArray,
50+
def __call__(key: Array,
5251
shape: core.Shape,
5352
dtype: DTypeLikeInexact = jnp.float_) -> Array:
5453
raise NotImplementedError
5554

5655
@export
57-
def zeros(key: KeyArray,
56+
def zeros(key: Array,
5857
shape: core.Shape,
5958
dtype: DTypeLikeInexact = jnp.float_) -> Array:
6059
"""An initializer that returns a constant array full of zeros.
@@ -69,7 +68,7 @@ def zeros(key: KeyArray,
6968
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
7069

7170
@export
72-
def ones(key: KeyArray,
71+
def ones(key: Array,
7372
shape: core.Shape,
7473
dtype: DTypeLikeInexact = jnp.float_) -> Array:
7574
"""An initializer that returns a constant array full of ones.
@@ -100,7 +99,7 @@ def constant(value: ArrayLike,
10099
Array([[-7., -7., -7.],
101100
[-7., -7., -7.]], dtype=float32)
102101
"""
103-
def init(key: KeyArray,
102+
def init(key: Array,
104103
shape: core.Shape,
105104
dtype: DTypeLikeInexact = dtype) -> Array:
106105
dtype = dtypes.canonicalize_dtype(dtype)
@@ -126,7 +125,7 @@ def uniform(scale: RealNumeric = 1e-2,
126125
Array([[7.298188 , 8.691938 , 8.7230015],
127126
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
128127
"""
129-
def init(key: KeyArray,
128+
def init(key: Array,
130129
shape: core.Shape,
131130
dtype: DTypeLikeInexact = dtype) -> Array:
132131
dtype = dtypes.canonicalize_dtype(dtype)
@@ -152,7 +151,7 @@ def normal(stddev: RealNumeric = 1e-2,
152151
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
153152
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
154153
"""
155-
def init(key: KeyArray,
154+
def init(key: Array,
156155
shape: core.Shape,
157156
dtype: DTypeLikeInexact = dtype) -> Array:
158157
dtype = dtypes.canonicalize_dtype(dtype)
@@ -189,7 +188,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
189188
[-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
190189
"""
191190

192-
def init(key: KeyArray,
191+
def init(key: Array,
193192
shape: core.Shape,
194193
dtype: DTypeLikeInexact = dtype) -> Array:
195194
dtype = dtypes.canonicalize_dtype(dtype)
@@ -230,7 +229,7 @@ def _compute_fans(shape: Sequence[int],
230229
fan_out = out_size * receptive_field_size
231230
return fan_in, fan_out
232231

233-
def _complex_uniform(key: KeyArray,
232+
def _complex_uniform(key: Array,
234233
shape: Sequence[int],
235234
dtype: DTypeLikeInexact) -> Array:
236235
"""
@@ -244,7 +243,7 @@ def _complex_uniform(key: KeyArray,
244243
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
245244
return r * jnp.exp(1j * theta)
246245

247-
def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
246+
def _complex_truncated_normal(key: Array, upper: ArrayLike,
248247
shape: Sequence[int],
249248
dtype: DTypeLikeInexact) -> Array:
250249
"""
@@ -314,7 +313,7 @@ def variance_scaling(
314313
dtype: the dtype of the weights.
315314
"""
316315

317-
def init(key: KeyArray,
316+
def init(key: Array,
318317
shape: core.Shape,
319318
dtype: DTypeLikeInexact = dtype) -> Array:
320319
shape = core.canonicalize_shape(shape)
@@ -599,7 +598,7 @@ def orthogonal(scale: RealNumeric = 1.0,
599598
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
600599
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
601600
"""
602-
def init(key: KeyArray,
601+
def init(key: Array,
603602
shape: core.Shape,
604603
dtype: DTypeLikeInexact = dtype) -> Array:
605604
dtype = dtypes.canonicalize_dtype(dtype)
@@ -654,7 +653,7 @@ def delta_orthogonal(
654653
655654
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
656655
"""
657-
def init(key: KeyArray,
656+
def init(key: Array,
658657
shape: core.Shape,
659658
dtype: DTypeLikeInexact = dtype) -> Array:
660659
dtype = dtypes.canonicalize_dtype(dtype)

0 commit comments

Comments
 (0)