Skip to content

Commit dfc6076

Browse files
author
jax authors
committed
Merge pull request #21744 from superbobry:typing
PiperOrigin-RevId: 641339815
2 parents 136289e + 0786da8 commit dfc6076

File tree

2 files changed

+28
-34
lines changed

2 files changed

+28
-34
lines changed

jax/_src/internal_test_util/test_harnesses.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@
6262
from jax._src.lib import xla_client
6363
from jax._src import random as jax_random
6464

65+
# mypy generates a lot of false positive due to re-assigned variables.
66+
# mypy: disable-error-code="assignment, no-redef"
67+
6568
# The code in this file relies on the values of some flags that are defined by
6669
# jtu. Note that the following can not always be moved to a test file since
6770
# then the test file has to import jtu first (to define the flags) which is not
@@ -172,9 +175,9 @@ def __init__(self,
172175
self.group_name = jtu.sanitize_test_name(group_name)
173176
self.name = jtu.sanitize_test_name(name)
174177
self.fullname = self.name if self.group_name is None else f"{self.group_name}_{self.name}"
175-
self.fun = fun # type: ignore[assignment]
178+
self.fun = fun
176179
self.arg_descriptors = arg_descriptors
177-
self.rng_factory = rng_factory # type: ignore[assignment]
180+
self.rng_factory = rng_factory
178181
self.jax_unimplemented = jax_unimplemented
179182
self.dtype = dtype
180183
self.params = params
@@ -2060,18 +2063,17 @@ def _make_slice_harness(name,
20602063
define(
20612064
lax.slice_p,
20622065
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}",
2063-
# type: ignore
20642066
lax.slice,
20652067
[
2066-
RandArg(shape, dtype), # type: ignore
2067-
StaticArg(start_indices), # type: ignore
2068-
StaticArg(limit_indices), # type: ignore
2068+
RandArg(shape, dtype),
2069+
StaticArg(start_indices),
2070+
StaticArg(limit_indices),
20692071
StaticArg(strides)
2070-
], # type: ignore
2072+
],
20712073
dtype=dtype,
2072-
shape=shape, # type: ignore
2073-
start_indices=start_indices, # type: ignore
2074-
limit_indices=limit_indices) # type: ignore
2074+
shape=shape,
2075+
start_indices=start_indices,
2076+
limit_indices=limit_indices)
20752077

20762078

20772079
# Test first all dtypes
@@ -2161,17 +2163,16 @@ def _make_dynamic_slice_harness(name,
21612163
define(
21622164
lax.dynamic_slice_p,
21632165
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}",
2164-
# type: ignore
21652166
lax.dynamic_slice,
21662167
[
2167-
RandArg(shape, dtype), # type: ignore
2168+
RandArg(shape, dtype),
21682169
np.array(list(start_indices)),
21692170
StaticArg(tuple(map(operator.sub, limit_indices, start_indices)))
2170-
], # type: ignore
2171+
],
21712172
dtype=dtype,
2172-
shape=shape, # type: ignore
2173-
start_indices=start_indices, # type: ignore
2174-
limit_indices=limit_indices, # type: ignore
2173+
shape=shape,
2174+
start_indices=start_indices,
2175+
limit_indices=limit_indices,
21752176
enable_xla=enable_xla)
21762177

21772178

@@ -2218,19 +2219,19 @@ def _make_dynamic_update_slice_harness(name,
22182219
define(
22192220
lax.dynamic_update_slice_p,
22202221
(
2221-
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
2222+
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"
22222223
f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
22232224
f"_{start_indices=}_{enable_xla=}"),
22242225
lax.dynamic_update_slice,
22252226
[
2226-
RandArg(shape, dtype), # type: ignore
2227-
RandArg(update_shape, dtype), # type: ignore
2227+
RandArg(shape, dtype),
2228+
RandArg(update_shape, dtype),
22282229
np.array(start_indices)
2229-
], # type: ignore
2230+
],
22302231
dtype=dtype,
2231-
shape=shape, # type: ignore
2232-
start_indices=start_indices, # type: ignore
2233-
update_shape=update_shape, # type: ignore
2232+
shape=shape,
2233+
start_indices=start_indices,
2234+
update_shape=update_shape,
22342235
enable_xla=enable_xla)
22352236

22362237

@@ -2261,12 +2262,12 @@ def _make_squeeze_harness(name,
22612262
dtype=np.float32):
22622263
define(
22632264
lax.squeeze_p,
2264-
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore
2265+
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}",
22652266
lax.squeeze,
2266-
[RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type]
2267+
[RandArg(shape, dtype), StaticArg(dimensions)],
22672268
dtype=dtype,
22682269
arg_shape=shape,
2269-
dimensions=dimensions) # type: ignore[has-type]
2270+
dimensions=dimensions)
22702271

22712272

22722273
# Test first all dtypes
@@ -3312,6 +3313,7 @@ def _make_conv_harness(name,
33123313
lhs_dilation=lhs_dilation,
33133314
rhs_dilation=rhs_dilation)
33143315

3316+
key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]]
33153317
key_types = [((4,), np.uint32)]
33163318
if config.enable_x64.value:
33173319
key_types.append(((2,), np.uint64))

pyproject.toml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,6 @@ module = [
4444
]
4545
ignore_missing_imports = true
4646

47-
[[tool.mypy.overrides]]
48-
module = [
49-
"jax.interpreters.autospmd",
50-
"jax.lax.lax_parallel",
51-
"jax._src.internal_test_util.test_harnesses",
52-
]
53-
ignore_errors = true
54-
5547
[tool.pytest.ini_options]
5648
markers = [
5749
"multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",

0 commit comments

Comments
 (0)