Skip to content

Commit 3ad1985

Browse files
committed
Bumped mypy and ruff versions used by pre-commit
1 parent 0d7ef9c commit 3ad1985

File tree

7 files changed

+11
-10
lines changed

7 files changed

+11
-10
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ repos:
2626
files: \.py$
2727

2828
- repo: https://github.com/astral-sh/ruff-pre-commit
29-
rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1
29+
rev: 8983acb92ee4b01924893632cf90af926fa608f0 # frozen: v0.7.0
3030
hooks:
3131
- id: ruff
3232

3333
- repo: https://github.com/pre-commit/mirrors-mypy
34-
rev: 'd4911cfb7f1010759fde68da196036feeb25b99d' # frozen: v1.11.2
34+
rev: '102bbee94061ff02fd361ec29c27b7cb26582f5f' # frozen: v1.12.2
3535
hooks:
3636
- id: mypy
3737
files: (jax/|tests/typing_test\.py)

jax/_src/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3270,7 +3270,7 @@ def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
32703270
) -> pp.Doc:
32713271
rule = (_pp_eqn if not settings.custom_pp_eqn_rules else
32723272
pp_eqn_rules.get(eqn.primitive, _pp_eqn))
3273-
doc = rule(eqn, context, settings) # type: ignore[operator]
3273+
doc = rule(eqn, context, settings)
32743274
user_frame = source_info_util.user_frame(eqn.source_info)
32753275
return doc if user_frame is None else pp.source_map(doc, user_frame)
32763276

jax/_src/interpreters/partial_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2738,7 +2738,7 @@ def inline_jaxpr_into_trace(
27382738
outvars = [Var('', v.aval) for v in eqn.outvars]
27392739
src_ = (src if not eqn.source_info.name_stack else
27402740
src.replace(name_stack=src.name_stack + eqn.source_info.name_stack))
2741-
trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) # type: ignore
2741+
trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_))
27422742
map(env.setdefault, eqn.outvars, outvars)
27432743

27442744
tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars],

jax/_src/interpreters/pxla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1750,7 +1750,7 @@ def _get_and_check_device_assignment(
17501750
elif first_sharding_info is None:
17511751
final_device_assignment = (_get_default_device(),)
17521752
else:
1753-
final_device_assignment = first_sharding_info[0]
1753+
final_device_assignment = first_sharding_info[0] # type: ignore
17541754
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
17551755

17561756
MaybeSharding = Union[JSharding, UnspecifiedValue]

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6477,7 +6477,7 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
64776477
if (not dtypes.issubdtype(start_dtype, np.integer) and
64786478
not dtypes.issubdtype(start_dtype, dtypes.extended)):
64796479
ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil
6480-
start = ceil_(start).astype(int) # type: ignore[operator]
6480+
start = ceil_(start).astype(int)
64816481
return lax.iota(dtype, start) # type: ignore[arg-type]
64826482
else:
64836483
if step is None and start == 0 and stop is not None:

jax/_src/pallas/pallas_call.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def _pallas_call_impl_interpret(
254254
num_inout_blocks = len(block_args) + len(out)
255255
grid_start_indices = (jnp.int32(0),) * len(grid)
256256
if grid:
257-
num_iterations = reduce(jnp.multiply, grid)
257+
num_iterations = reduce(jnp.multiply, grid) # type: ignore[arg-type]
258258
else:
259259
# Base case is always one iteration when grid is ()
260260
num_iterations = 1
@@ -1174,7 +1174,7 @@ def pallas_call_checkify_oob_grid(error: checkify.Error,
11741174
)
11751175
grid_start_indices = (jnp.int32(0),) * len(grid)
11761176
if grid:
1177-
num_iterations = reduce(jnp.multiply, grid)
1177+
num_iterations = reduce(jnp.multiply, grid) # type: ignore[arg-type]
11781178
else:
11791179
# Base case is always one iteration when grid is ()
11801180
num_iterations = 1

jax/_src/pjit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ def _extract_implicit_args(
779779
args[d1.val] = d2
780780
assert core.same_referent(args[d1.val], d2)
781781
assert all(x is not None for x in args)
782-
return [x for x, (_, e) in zip(args, in_type) if not e] # pytype: disable=bad-return-type
782+
return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore
783783

784784
def _flat_axes_specs(abstracted_axes, *args, **kwargs
785785
) -> list[pe.AbstractedAxesSpec] | None:
@@ -1545,6 +1545,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
15451545
else:
15461546
resolved_in_shardings.append(arg_s)
15471547
else:
1548+
assert isinstance(arg_s, sharding.Sharding)
15481549
if dispatch.is_single_device_sharding(arg_s):
15491550
resolved_in_shardings.append(UNSPECIFIED)
15501551
else:
@@ -1581,7 +1582,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
15811582
not isinstance(arg_s, PmapSharding) and
15821583
not op_shardings.are_op_shardings_equal(
15831584
pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore
1584-
arg_s._to_xla_hlo_sharding(arg.ndim))):
1585+
arg_s._to_xla_hlo_sharding(arg.ndim))): # type: ignore
15851586
raise ValueError('Sharding passed to pjit does not match the sharding '
15861587
'on the respective arg. '
15871588
f'Got pjit sharding: {pjit_in_s},\n'

0 commit comments

Comments
 (0)