|
62 | 62 | from jax._src.lib import xla_client
|
63 | 63 | from jax._src import random as jax_random
|
64 | 64 |
|
| 65 | +# mypy generates a lot of false positive due to re-assigned variables. |
| 66 | +# mypy: disable-error-code="assignment, no-redef" |
| 67 | + |
65 | 68 | # The code in this file relies on the values of some flags that are defined by
|
66 | 69 | # jtu. Note that the following can not always be moved to a test file since
|
67 | 70 | # then the test file has to import jtu first (to define the flags) which is not
|
@@ -172,9 +175,9 @@ def __init__(self,
|
172 | 175 | self.group_name = jtu.sanitize_test_name(group_name)
|
173 | 176 | self.name = jtu.sanitize_test_name(name)
|
174 | 177 | self.fullname = self.name if self.group_name is None else f"{self.group_name}_{self.name}"
|
175 |
| - self.fun = fun # type: ignore[assignment] |
| 178 | + self.fun = fun |
176 | 179 | self.arg_descriptors = arg_descriptors
|
177 |
| - self.rng_factory = rng_factory # type: ignore[assignment] |
| 180 | + self.rng_factory = rng_factory |
178 | 181 | self.jax_unimplemented = jax_unimplemented
|
179 | 182 | self.dtype = dtype
|
180 | 183 | self.params = params
|
@@ -2060,18 +2063,17 @@ def _make_slice_harness(name,
|
2060 | 2063 | define(
|
2061 | 2064 | lax.slice_p,
|
2062 | 2065 | f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}",
|
2063 |
| - # type: ignore |
2064 | 2066 | lax.slice,
|
2065 | 2067 | [
|
2066 |
| - RandArg(shape, dtype), # type: ignore |
2067 |
| - StaticArg(start_indices), # type: ignore |
2068 |
| - StaticArg(limit_indices), # type: ignore |
| 2068 | + RandArg(shape, dtype), |
| 2069 | + StaticArg(start_indices), |
| 2070 | + StaticArg(limit_indices), |
2069 | 2071 | StaticArg(strides)
|
2070 |
| - ], # type: ignore |
| 2072 | + ], |
2071 | 2073 | dtype=dtype,
|
2072 |
| - shape=shape, # type: ignore |
2073 |
| - start_indices=start_indices, # type: ignore |
2074 |
| - limit_indices=limit_indices) # type: ignore |
| 2074 | + shape=shape, |
| 2075 | + start_indices=start_indices, |
| 2076 | + limit_indices=limit_indices) |
2075 | 2077 |
|
2076 | 2078 |
|
2077 | 2079 | # Test first all dtypes
|
@@ -2161,17 +2163,16 @@ def _make_dynamic_slice_harness(name,
|
2161 | 2163 | define(
|
2162 | 2164 | lax.dynamic_slice_p,
|
2163 | 2165 | f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}",
|
2164 |
| - # type: ignore |
2165 | 2166 | lax.dynamic_slice,
|
2166 | 2167 | [
|
2167 |
| - RandArg(shape, dtype), # type: ignore |
| 2168 | + RandArg(shape, dtype), |
2168 | 2169 | np.array(list(start_indices)),
|
2169 | 2170 | StaticArg(tuple(map(operator.sub, limit_indices, start_indices)))
|
2170 |
| - ], # type: ignore |
| 2171 | + ], |
2171 | 2172 | dtype=dtype,
|
2172 |
| - shape=shape, # type: ignore |
2173 |
| - start_indices=start_indices, # type: ignore |
2174 |
| - limit_indices=limit_indices, # type: ignore |
| 2173 | + shape=shape, |
| 2174 | + start_indices=start_indices, |
| 2175 | + limit_indices=limit_indices, |
2175 | 2176 | enable_xla=enable_xla)
|
2176 | 2177 |
|
2177 | 2178 |
|
@@ -2218,19 +2219,19 @@ def _make_dynamic_update_slice_harness(name,
|
2218 | 2219 | define(
|
2219 | 2220 | lax.dynamic_update_slice_p,
|
2220 | 2221 | (
|
2221 |
| - f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore |
| 2222 | + f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" |
2222 | 2223 | f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
|
2223 | 2224 | f"_{start_indices=}_{enable_xla=}"),
|
2224 | 2225 | lax.dynamic_update_slice,
|
2225 | 2226 | [
|
2226 |
| - RandArg(shape, dtype), # type: ignore |
2227 |
| - RandArg(update_shape, dtype), # type: ignore |
| 2227 | + RandArg(shape, dtype), |
| 2228 | + RandArg(update_shape, dtype), |
2228 | 2229 | np.array(start_indices)
|
2229 |
| - ], # type: ignore |
| 2230 | + ], |
2230 | 2231 | dtype=dtype,
|
2231 |
| - shape=shape, # type: ignore |
2232 |
| - start_indices=start_indices, # type: ignore |
2233 |
| - update_shape=update_shape, # type: ignore |
| 2232 | + shape=shape, |
| 2233 | + start_indices=start_indices, |
| 2234 | + update_shape=update_shape, |
2234 | 2235 | enable_xla=enable_xla)
|
2235 | 2236 |
|
2236 | 2237 |
|
@@ -2261,12 +2262,12 @@ def _make_squeeze_harness(name,
|
2261 | 2262 | dtype=np.float32):
|
2262 | 2263 | define(
|
2263 | 2264 | lax.squeeze_p,
|
2264 |
| - f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore |
| 2265 | + f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", |
2265 | 2266 | lax.squeeze,
|
2266 |
| - [RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type] |
| 2267 | + [RandArg(shape, dtype), StaticArg(dimensions)], |
2267 | 2268 | dtype=dtype,
|
2268 | 2269 | arg_shape=shape,
|
2269 |
| - dimensions=dimensions) # type: ignore[has-type] |
| 2270 | + dimensions=dimensions) |
2270 | 2271 |
|
2271 | 2272 |
|
2272 | 2273 | # Test first all dtypes
|
@@ -3312,6 +3313,7 @@ def _make_conv_harness(name,
|
3312 | 3313 | lhs_dilation=lhs_dilation,
|
3313 | 3314 | rhs_dilation=rhs_dilation)
|
3314 | 3315 |
|
| 3316 | +key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]] |
3315 | 3317 | key_types = [((4,), np.uint32)]
|
3316 | 3318 | if config.enable_x64.value:
|
3317 | 3319 | key_types.append(((2,), np.uint64))
|
|
0 commit comments