Skip to content

Commit b7af1eb

Browse files
Merge pull request jax-ml#25381 from jakevdp:mypy-np22
PiperOrigin-RevId: 705248189
2 parents e55bbc7 + f4f4bf6 commit b7af1eb

File tree

11 files changed

+14
-14
lines changed

11 files changed

+14
-14
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ repos:
3636
- id: mypy
3737
files: (jax/|tests/typing_test\.py)
3838
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
39-
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy~=2.1.0]
39+
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0]
4040
args: [--config=pyproject.toml]
4141

4242
- repo: https://github.com/mwouts/jupytext

jax/_src/dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
419419
return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}
420420

421421
# Otherwise, fall back to numpy.issubdtype
422-
return np.issubdtype(a_sctype, b_sctype)
422+
return bool(np.issubdtype(a_sctype, b_sctype))
423423

424424
can_cast = np.can_cast
425425

jax/_src/interpreters/mlir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2801,7 +2801,7 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined
28012801
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
28022802
if minor_to_major is None:
28032803
# Needed for token layouts
2804-
layout = np.zeros((0,), dtype="int64")
2804+
layout: np.ndarray = np.zeros((0,), dtype="int64")
28052805
else:
28062806
layout = np.array(minor_to_major, dtype="int64")
28072807
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())

jax/_src/mesh_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def _create_device_mesh_for_nd_torus_splitting_axes(
386386
)
387387
):
388388
best_logical_axis_assignment = logical_axis_assignment
389-
assignment[:, logical_axis] = best_logical_axis_assignment
389+
assignment[:, logical_axis] = best_logical_axis_assignment # type: ignore # numpy 2.2
390390

391391
# Read out the assignment.
392392
logical_mesh = _generate_logical_mesh(
@@ -597,10 +597,10 @@ def _generate_logical_mesh(
597597
zip(logical_indices, physical_indices, range(len(logical_indices)))
598598
)
599599
)
600-
logical_mesh = np.transpose(logical_mesh, transpose_axes)
600+
logical_mesh = np.transpose(logical_mesh, transpose_axes) # type: ignore # numpy 2.2
601601

602602
# Reshape to add the trivial dimensions back.
603-
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape)
603+
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) # type: ignore # numpy 2.2
604604

605605
return logical_mesh
606606

jax/_src/numpy/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
13741374
x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
13751375
else:
13761376
if rcond is None:
1377-
rcond = jnp.finfo(dtype).eps * max(n, m)
1377+
rcond = float(jnp.finfo(dtype).eps) * max(n, m)
13781378
else:
13791379
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
13801380
u, s, vt = svd(a, full_matrices=False)

jax/_src/numpy/polynomial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
246246

247247
# set rcond
248248
if rcond is None:
249-
rcond = len(x_arr) * finfo(x_arr.dtype).eps
249+
rcond = len(x_arr) * float(finfo(x_arr.dtype).eps)
250250
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
251251
# set up least squares equation for powers of x
252252
lhs = vander(x_arr, order)

jax/_src/op_shardings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def op_sharding_to_numpy_indices(
100100

101101
for i, idxs in enumerate(itertools.product(*axis_indices)):
102102
for _ in range(num_replicas):
103-
indices[next(device_it)] = idxs
103+
indices[next(device_it)] = idxs # type: ignore # numpy 2.2
104104
return indices
105105

106106

jax/_src/sharding_impls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def __repr__(self) -> str:
738738
ids = self._ids.copy()
739739
platform_name = self._devices[0].platform.upper()
740740
for idx, x in np.ndenumerate(ids):
741-
ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))
741+
ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x)) # type: ignore # numpy 2.2
742742
body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
743743
max_line_width=100)
744744
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'

jax/_src/sharding_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
9797
# is used to extract the corresponding shard of the logical array.
9898
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
9999
for i, idxs in enumerate(itertools.product(*axis_indices)):
100-
shard_indices[i] = idxs
100+
shard_indices[i] = idxs # type: ignore # numpy 2.2
101101
shard_indices = shard_indices.reshape(shard_indices_shape)
102102

103103
# Ensure that each sharded axis is used exactly once in the mesh mapping

jax/experimental/sparse/bcoo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int],
578578
raise ValueError("Cannot transpose with respect to sparse indices")
579579
assert data_ct.dtype == data.aval.dtype
580580
ct_spinfo = SparseInfo(tuple(spinfo.shape[p] for p in permutation))
581-
rev_permutation = list(np.argsort(permutation))
581+
rev_permutation = list(map(int, np.argsort(permutation)))
582582
# TODO(jakevdp) avoid dummy indices?
583583
dummy_indices = jnp.zeros([1 for i in range(indices.ndim - 2)] + list(indices.shape[-2:]), dtype=int)
584584
data_trans, _ = _bcoo_transpose(data_ct, dummy_indices, permutation=rev_permutation, spinfo=ct_spinfo)
@@ -865,7 +865,7 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
865865
dims: DotDimensionNumbers = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch))
866866
lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract)))
867867
permutation = list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs
868-
out_axes = list(np.argsort(permutation))
868+
out_axes = list(map(int, np.argsort(permutation)))
869869

870870
# Determine whether efficient approach is possible:
871871
placeholder_data = jnp.empty((lhs_indices.ndim - 2) * (1,) + (lhs_indices.shape[-2],))

0 commit comments

Comments
 (0)