Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def deserialize(blob: bytearray) -> Exported:


class _SerializeAuxData(Protocol):
def __call__(self, aux_data: PyTreeAuxData) -> bytes:
def __call__(self, aux_data: PyTreeAuxData, /) -> bytes:
"""Serializes the PyTree node AuxData.

The AuxData is returned by the ``flatten_func`` registered by
Expand All @@ -366,7 +366,7 @@ def __call__(self, aux_data: PyTreeAuxData) -> bytes:


class _DeserializeAuxData(Protocol):
def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData:
def __call__(self, serialized_aux_data: bytes, /) -> PyTreeAuxData:
"""Deserializes the PyTree node AuxData.

The result will be passed to ``_BuildFromChildren``.
Expand Down Expand Up @@ -497,6 +497,7 @@ def register_namedtuple_serialization(
def serialize_auxdata(aux_data: PyTreeAuxData) -> bytes:
# Store the serialized keys in the serialized auxdata
del aux_data
# pyrefly: ignore[missing-attribute]
return json.dumps(nodetype._fields).encode("utf-8")

def deserialize_auxdata(serialized_aux_data: bytes) -> PyTreeAuxData:
Expand Down Expand Up @@ -901,8 +902,8 @@ def _module_to_bytecode(module: ir.Module) -> bytes:
# Note that this does not verify any JAX custom calls, which are only
# guaranteed 3w of forward compatibility, and only prevents use of new
# StableHLO features from failing on older hardware.
target_version = hlo.get_version_from_compatibility_requirement(
hlo.StablehloCompatibilityRequirement.WEEK_4)
target_version = hlo.get_version_from_compatibility_requirement( # pyrefly: ignore[missing-attribute]
hlo.StablehloCompatibilityRequirement.WEEK_4) # pyrefly: ignore[missing-attribute]
module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore
mlir_str, target_version, xb.get_backend().serialize_with_sdy)
return module_serialized
Expand Down Expand Up @@ -995,7 +996,7 @@ def is_token(typ, attrs):
new_arg_attrs = []
for idx in new_main_arg_indices:
new_arg_attr = {}
for attr in arg_attrs[idx]:
for attr in arg_attrs[idx]: # pyrefly: ignore[not-iterable]
if attr.name == "tf.aliasing_output":
i = new_main_result_indices.index(attr.attr.value)
new_arg_attr[attr.name] = ir.IntegerAttr.get(
Expand Down Expand Up @@ -1374,7 +1375,7 @@ def flattened_primal_fun_jax(*args_flat):
if has_named_shardings or mesh:
vjp_in_shardings = tuple(
_get_named_sharding(has_named_shardings, named_sharding, # type: ignore
hlo_sharding, aval, mesh)
hlo_sharding, aval, mesh) # pyrefly: ignore[bad-argument-type]
for named_sharding, hlo_sharding, aval in zip(
itertools.chain(in_named_shardings, out_named_shardings),
itertools.chain(in_shardings_hlo, out_shardings_hlo),
Expand Down Expand Up @@ -1548,8 +1549,8 @@ def _call_exported_impl(*args, exported: Exported):
def get_mesh_from_symbol(symtab: ir.SymbolTable) -> mesh_lib.AbstractMesh:
if "mesh" not in symtab:
return mesh_lib.empty_abstract_mesh
mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh)
axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes]
mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh) # pyrefly: ignore[missing-attribute]
axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes] # pyrefly: ignore[missing-attribute]
if not axes:
return mesh_lib.empty_abstract_mesh
axes_sizes = tuple(a.size for a in axes)
Expand Down
21 changes: 13 additions & 8 deletions jax/_src/export/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from __future__ import annotations

import types
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterable
import itertools
from functools import partial
from typing import Any, TypeVar
from typing import cast, Any, TypeVar

try:
import flatbuffers
Expand Down Expand Up @@ -166,9 +166,10 @@ def _serialize_exported(
def _serialize_array(
builder: flatbuffers.Builder,
serialize_one: Callable[[flatbuffers.Builder, T], int],
elements: Sequence[T],
elements: Iterable[T],
) -> int:
element_offsets = [serialize_one(builder, e) for e in elements]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since elements is now an iterable rather than a sequence, I'd consider adding del elements in the next line to emphasize that it may be consumed by this iteration.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, done.

del elements
ser_flatbuf.PyTreeDefStartChildrenVector(builder, len(element_offsets))
for sc in reversed(element_offsets):
builder.PrependUOffsetTRelative(sc)
Expand Down Expand Up @@ -216,8 +217,10 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
partial(_deserialize_aval, scope=scope, sharding=None))
out_avals = _deserialize_tuple(exp.OutAvalsLength, exp.OutAvals,
partial(_deserialize_aval, scope=scope, sharding=None))
in_shardings_hlo, in_shardings = in_shardings, (None,) * len(in_shardings) # type: ignore
out_shardings_hlo, out_shardings = out_shardings, (None,) * len(out_shardings) # type: ignore
in_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], in_shardings)
in_shardings = (None,) * len(in_shardings) # type: ignore
out_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], out_shardings)
out_shardings = (None,) * len(out_shardings) # type: ignore
platforms = _deserialize_tuple(
exp.PlatformsLength,
exp.Platforms,
Expand Down Expand Up @@ -313,6 +316,7 @@ def serialize_key(builder, k):
builder, serialize_key, node_data[1]
)
elif node_type in _export.serialization_registry:
assert node_type is not None
kind = ser_flatbuf.PyTreeDefKind.custom
serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
custom_name = builder.CreateString(serialized_name)
Expand Down Expand Up @@ -497,7 +501,7 @@ def _deserialize_partition_spec_one_axis(
def _serialize_partition_spec(builder: flatbuffers.Builder,
spec: partition_spec.PartitionSpec) -> int:
partitions = _serialize_array(builder, _serialize_partition_spec_one_axis,
spec._partitions)
spec._partitions) # pyrefly: ignore[bad-argument-type]
reduced = _serialize_array(builder, # type: ignore
lambda builder, ps: builder.CreateString(ps),
spec.reduced)
Expand Down Expand Up @@ -583,8 +587,9 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue, *,
else:
mem_space = core.MemorySpace.Device

aval = core.ShapedArray(shape, dtype, memory_space=mem_space)
return core.update_aval_with_sharding(aval, sharding)
return core.update_aval_with_sharding(
core.ShapedArray(shape, dtype, memory_space=mem_space), sharding
)


def _serialize_sharding(
Expand Down
1 change: 1 addition & 0 deletions jax/_src/export/serialization_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# pytype: skip-file
# automatically generated by the FlatBuffers compiler, do not modify

Expand Down
30 changes: 15 additions & 15 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import copy
import operator as op
import tokenize
from typing import Any, Union, overload
from typing import Any, TypeAlias, Union, overload
import warnings

import numpy as np
Expand All @@ -52,8 +52,8 @@
DType = Any

# Tuples of terms and their coefficients, sorted with the largest term first.
SortedTerms = Sequence[tuple["_DimTerm", int]]
SortedFactors = Sequence[tuple["_DimFactor", int]]
SortedTerms: TypeAlias = Sequence[tuple["_DimTerm", int]]
SortedFactors: TypeAlias = Sequence[tuple["_DimFactor", int]]

# Normalization rules represent the explicit constraint `t*tk == e` as
# a mapping of `t` to `(e, tk)`.
Expand Down Expand Up @@ -583,14 +583,14 @@ def _get_vars(self) -> set[str]:
@overload
@staticmethod
def _linear_combination_sorted_pairs(
e1: SortedTerms, i1: int, f1: int,
e2: SortedTerms, i2: int, f2: int) -> SortedTerms: ... # type: ignore[bad-return-type,unused-ignore]
pairs1: SortedTerms, i1: int, f1: int,
pairs2: SortedTerms, i2: int, f2: int) -> SortedTerms: ... # type: ignore[bad-return-type,unused-ignore]

@overload
@staticmethod
def _linear_combination_sorted_pairs(
e1: SortedFactors, i1: int, f1: int,
e2: SortedFactors, i2: int, f2: int) -> SortedFactors: ... # type: ignore[bad-return-type,unused-ignore]
pairs1: SortedFactors, i1: int, f1: int,
pairs2: SortedFactors, i2: int, f2: int) -> SortedFactors: ... # type: ignore[bad-return-type,unused-ignore]

@staticmethod
def _linear_combination_sorted_pairs(
Expand Down Expand Up @@ -862,7 +862,7 @@ def __gt__(self, other: DimSize):
def __lt__(self, other: DimSize):
return not _geq_decision(self, other, lambda: f"'{self}' < '{other}'")

def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]:
def _divmod(self, divisor: DimSize) -> tuple[DimSize, DimSize]:
"""
Floor division with remainder (divmod) generalized to expressions.
If the `divisor` is not a constant, the remainder must be 0.
Expand Down Expand Up @@ -1201,9 +1201,9 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
dtypes.register_weak_scalar_type(_DimExpr)

def _convertible_to_int(p: DimSize) -> bool:
def _convertible_to_int(p: Any) -> bool:
try:
op.index(p) # type: ignore
op.index(p)
return True
except:
return False
Expand All @@ -1218,7 +1218,7 @@ def _ensure_poly(p: DimSize,
return _DimExpr(((_DimTerm_one, op.index(p)),), scope)
raise TypeError(f"Symbolic dimension {operation_name} not supported for {p}.")

def _convertible_to_poly(p: DimSize) -> bool:
def _convertible_to_poly(p: Any) -> bool:
return isinstance(p, _DimExpr) or _convertible_to_int(p)

def is_symbolic_dim(p: DimSize) -> bool:
Expand Down Expand Up @@ -1626,11 +1626,11 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
tok = self.next_tok()
elif tok.exact_type == tokenize.PLUS:
tok = self.next_tok()
acc = None
acc: DimSize | None = None
while True:
t, tok = self.term(tok)
t_sign = - t if next_t_negated else t
acc = acc + t_sign if acc is not None else t_sign # type: ignore[operator]
acc = acc + t_sign if acc is not None else t_sign # type: ignore
if tok.exact_type in self.FOLLOW_EXPR:
return acc, tok
next_t_negated = (tok.exact_type == tokenize.MINUS)
Expand All @@ -1640,7 +1640,7 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
FOLLOW_TERM = FOLLOW_EXPR + [tokenize.PLUS, tokenize.MINUS]
def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
# A term is product of factors. Each factor may be raised to an integer power.
acc = None
acc: DimSize | None = None
while True:
f, tok = self.factor(tok)
if tok.exact_type == tokenize.CIRCUMFLEX:
Expand All @@ -1649,7 +1649,7 @@ def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
power, tok = self.integer(tok)
f = f ** power

acc = acc * f if acc is not None else f # type: ignore[operator]
acc = acc * f if acc is not None else f # type: ignore
if tok.exact_type in self.FOLLOW_TERM:
return acc, tok # type: ignore[bad-return-type,unused-ignore]
tok = self.consume_token(tok, tokenize.STAR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
import datetime
import numpy
array = numpy.array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyrefly: ignore-errors
# ruff: noqa

import datetime
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/internal_test_util/test_harnesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def dtypes_to_str(dtype_list: Sequence[DType], empty_means_all=False) -> str:
if not dtype_list and empty_means_all:
return "all"

names = {np.dtype(dt).name for dt in dtype_list}
names: set[str] = {np.dtype(dt).name for dt in dtype_list}
signed = {"int8", "int16", "int32", "int64"}
if signed <= names:
names = (names - signed) | {"signed"}
Expand Down Expand Up @@ -2378,7 +2378,7 @@ def _make_select_and_scatter_add_harness(name,
padding=padding)


for dtype in set(jtu.dtypes.all) - {np.complex64, np.complex128}:
for dtype in set(jtu.dtypes.all) - {np.complex64, np.complex128}: # pyrefly: ignore[unsupported-operation]
_make_select_and_scatter_add_harness("dtypes", dtype=dtype)

# Validate different reduction primitives
Expand All @@ -2402,7 +2402,7 @@ def _make_select_and_scatter_add_harness(name,
_make_select_and_scatter_add_harness("window_strides", window_strides=(1, 2, 3))

# Validate dtypes on TPU
for dtype in set(jtu.dtypes.all) - {
for dtype in set(jtu.dtypes.all) - { # pyrefly: ignore[unsupported-operation]
np.bool_, np.complex64, np.complex128, np.int8, np.uint8}:
for window_strides, window_dimensions, nb_inactive_dims in [((1, 2, 1),
(1, 3, 1), 2)]:
Expand Down Expand Up @@ -2485,7 +2485,7 @@ def reducer(*args):
init_val = np.array(init_value, dtype=dtype)
init_values = [init_val]
if nr_operands == 2:
init_values.append(np.int32(0.))
init_values.append(np.array(0, dtype=np.int32))
return lax.reduce(args[0:nr_operands], tuple(init_values),
computation, dimensions)
define(
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def _conv_general_dilated_lower(
return complex_conv(ctx, lhs, rhs)

lhs_spec, rhs_spec, out_spec = dimension_numbers
dnums = hlo.ConvDimensionNumbers.get(
dnums = hlo.ConvDimensionNumbers.get( # pyrefly: ignore[missing-attribute]
input_batch_dimension=lhs_spec[0],
input_feature_dimension=lhs_spec[1],
input_spatial_dimensions=list(lhs_spec[2:]),
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lax/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
# TODO: https://github.com/openxla/stablehlo/issues/1366
raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU")
return [
# pyrefly: ignore[missing-attribute]
hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
mlir.dense_int_array(fft_lengths)).result
]
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lazy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __getattr__(name: str) -> Any:
value = importlib.import_module(f"{package_name}.{name}")
# Update module-level globals to avoid calling ``__getattr__`` again
# for this ``name``.
assert owner_name is not None # pyrefly#40
setattr(sys.modules[owner_name], name, value)
return value
raise AttributeError(f"module '{package_name}' has no attribute '{name}'")
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _xla_gc_callback(*args):
try:
from jaxlib.mlir import ir # type: ignore[import-not-found]
except ImportError:
from mlir import ir
from mlir import ir # type: ignore[import-not-found]
mosaic_gpu_dialect.init_cc_mlir(ir)

import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401
Expand Down
Loading
Loading