Skip to content

Commit 37b4965

Browse files
committed
Another batch of Pyrefly fixes
1 parent ad8c8c4 commit 37b4965

27 files changed

+67
-39
lines changed

jax/_src/export/_export.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def deserialize(blob: bytearray) -> Exported:
357357

358358

359359
class _SerializeAuxData(Protocol):
360-
def __call__(self, aux_data: PyTreeAuxData) -> bytes:
360+
def __call__(self, aux_data: PyTreeAuxData, /) -> bytes:
361361
"""Serializes the PyTree node AuxData.
362362
363363
The AuxData is returned by the ``flatten_func`` registered by
@@ -366,7 +366,7 @@ def __call__(self, aux_data: PyTreeAuxData) -> bytes:
366366

367367

368368
class _DeserializeAuxData(Protocol):
369-
def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData:
369+
def __call__(self, serialized_aux_data: bytes, /) -> PyTreeAuxData:
370370
"""Deserializes the PyTree node AuxData.
371371
372372
The result will be passed to ``_BuildFromChildren``.
@@ -497,6 +497,7 @@ def register_namedtuple_serialization(
497497
def serialize_auxdata(aux_data: PyTreeAuxData) -> bytes:
498498
# Store the serialized keys in the serialized auxdata
499499
del aux_data
500+
# pyrefly: ignore[missing-attribute]
500501
return json.dumps(nodetype._fields).encode("utf-8")
501502

502503
def deserialize_auxdata(serialized_aux_data: bytes) -> PyTreeAuxData:
@@ -901,8 +902,8 @@ def _module_to_bytecode(module: ir.Module) -> bytes:
901902
# Note that this does not verify any JAX custom calls, which are only
902903
# guaranteed 3w of forward compatibility, and only prevents use of new
903904
# StableHLO features from failing on older hardware.
904-
target_version = hlo.get_version_from_compatibility_requirement(
905-
hlo.StablehloCompatibilityRequirement.WEEK_4)
905+
target_version = hlo.get_version_from_compatibility_requirement( # pyrefly: ignore[missing-attribute]
906+
hlo.StablehloCompatibilityRequirement.WEEK_4) # pyrefly: ignore[missing-attribute]
906907
module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore
907908
mlir_str, target_version, xb.get_backend().serialize_with_sdy)
908909
return module_serialized
@@ -995,7 +996,7 @@ def is_token(typ, attrs):
995996
new_arg_attrs = []
996997
for idx in new_main_arg_indices:
997998
new_arg_attr = {}
998-
for attr in arg_attrs[idx]:
999+
for attr in arg_attrs[idx]: # pyrefly: ignore[not-iterable]
9991000
if attr.name == "tf.aliasing_output":
10001001
i = new_main_result_indices.index(attr.attr.value)
10011002
new_arg_attr[attr.name] = ir.IntegerAttr.get(
@@ -1374,7 +1375,7 @@ def flattened_primal_fun_jax(*args_flat):
13741375
if has_named_shardings or mesh:
13751376
vjp_in_shardings = tuple(
13761377
_get_named_sharding(has_named_shardings, named_sharding, # type: ignore
1377-
hlo_sharding, aval, mesh)
1378+
hlo_sharding, aval, mesh) # pyrefly: ignore[bad-argument-type]
13781379
for named_sharding, hlo_sharding, aval in zip(
13791380
itertools.chain(in_named_shardings, out_named_shardings),
13801381
itertools.chain(in_shardings_hlo, out_shardings_hlo),
@@ -1548,8 +1549,8 @@ def _call_exported_impl(*args, exported: Exported):
15481549
def get_mesh_from_symbol(symtab: ir.SymbolTable) -> mesh_lib.AbstractMesh:
15491550
if "mesh" not in symtab:
15501551
return mesh_lib.empty_abstract_mesh
1551-
mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh)
1552-
axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes]
1552+
mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh) # pyrefly: ignore[missing-attribute]
1553+
axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes] # pyrefly: ignore[missing-attribute]
15531554
if not axes:
15541555
return mesh_lib.empty_abstract_mesh
15551556
axes_sizes = tuple(a.size for a in axes)

jax/_src/export/serialization.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from __future__ import annotations
1818

1919
import types
20-
from collections.abc import Callable, Sequence
20+
from collections.abc import Callable, Iterable
2121
import itertools
2222
from functools import partial
23-
from typing import Any, TypeVar
23+
from typing import cast, Any, TypeVar
2424

2525
try:
2626
import flatbuffers
@@ -166,9 +166,10 @@ def _serialize_exported(
166166
def _serialize_array(
167167
builder: flatbuffers.Builder,
168168
serialize_one: Callable[[flatbuffers.Builder, T], int],
169-
elements: Sequence[T],
169+
elements: Iterable[T],
170170
) -> int:
171171
element_offsets = [serialize_one(builder, e) for e in elements]
172+
del elements
172173
ser_flatbuf.PyTreeDefStartChildrenVector(builder, len(element_offsets))
173174
for sc in reversed(element_offsets):
174175
builder.PrependUOffsetTRelative(sc)
@@ -216,8 +217,10 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
216217
partial(_deserialize_aval, scope=scope, sharding=None))
217218
out_avals = _deserialize_tuple(exp.OutAvalsLength, exp.OutAvals,
218219
partial(_deserialize_aval, scope=scope, sharding=None))
219-
in_shardings_hlo, in_shardings = in_shardings, (None,) * len(in_shardings) # type: ignore
220-
out_shardings_hlo, out_shardings = out_shardings, (None,) * len(out_shardings) # type: ignore
220+
in_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], in_shardings)
221+
in_shardings = (None,) * len(in_shardings) # type: ignore
222+
out_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], out_shardings)
223+
out_shardings = (None,) * len(out_shardings) # type: ignore
221224
platforms = _deserialize_tuple(
222225
exp.PlatformsLength,
223226
exp.Platforms,
@@ -313,6 +316,7 @@ def serialize_key(builder, k):
313316
builder, serialize_key, node_data[1]
314317
)
315318
elif node_type in _export.serialization_registry:
319+
assert node_type is not None
316320
kind = ser_flatbuf.PyTreeDefKind.custom
317321
serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
318322
custom_name = builder.CreateString(serialized_name)
@@ -497,7 +501,7 @@ def _deserialize_partition_spec_one_axis(
497501
def _serialize_partition_spec(builder: flatbuffers.Builder,
498502
spec: partition_spec.PartitionSpec) -> int:
499503
partitions = _serialize_array(builder, _serialize_partition_spec_one_axis,
500-
spec._partitions)
504+
spec._partitions) # pyrefly: ignore[bad-argument-type]
501505
reduced = _serialize_array(builder, # type: ignore
502506
lambda builder, ps: builder.CreateString(ps),
503507
spec.reduced)
@@ -583,8 +587,9 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue, *,
583587
else:
584588
mem_space = core.MemorySpace.Device
585589

586-
aval = core.ShapedArray(shape, dtype, memory_space=mem_space)
587-
return core.update_aval_with_sharding(aval, sharding)
590+
return core.update_aval_with_sharding(
591+
core.ShapedArray(shape, dtype, memory_space=mem_space), sharding
592+
)
588593

589594

590595
def _serialize_sharding(

jax/_src/export/serialization_generated.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyrefly: ignore-errors
1516
# pytype: skip-file
1617
# automatically generated by the FlatBuffers compiler, do not modify
1718

jax/_src/export/shape_poly.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import copy
2929
import operator as op
3030
import tokenize
31-
from typing import Any, Union, overload
31+
from typing import Any, TypeAlias, Union, overload
3232
import warnings
3333

3434
import numpy as np
@@ -52,8 +52,8 @@
5252
DType = Any
5353

5454
# Tuples of terms and their coefficients, sorted with the largest term first.
55-
SortedTerms = Sequence[tuple["_DimTerm", int]]
56-
SortedFactors = Sequence[tuple["_DimFactor", int]]
55+
SortedTerms: TypeAlias = Sequence[tuple["_DimTerm", int]]
56+
SortedFactors: TypeAlias = Sequence[tuple["_DimFactor", int]]
5757

5858
# Normalization rules represent the explicit constraint `t*tk == e` as
5959
# a mapping of `t` to `(e, tk)`.
@@ -583,14 +583,14 @@ def _get_vars(self) -> set[str]:
583583
@overload
584584
@staticmethod
585585
def _linear_combination_sorted_pairs(
586-
e1: SortedTerms, i1: int, f1: int,
587-
e2: SortedTerms, i2: int, f2: int) -> SortedTerms: ... # type: ignore[bad-return-type,unused-ignore]
586+
pairs1: SortedTerms, i1: int, f1: int,
587+
pairs2: SortedTerms, i2: int, f2: int) -> SortedTerms: ... # type: ignore[bad-return-type,unused-ignore]
588588

589589
@overload
590590
@staticmethod
591591
def _linear_combination_sorted_pairs(
592-
e1: SortedFactors, i1: int, f1: int,
593-
e2: SortedFactors, i2: int, f2: int) -> SortedFactors: ... # type: ignore[bad-return-type,unused-ignore]
592+
pairs1: SortedFactors, i1: int, f1: int,
593+
pairs2: SortedFactors, i2: int, f2: int) -> SortedFactors: ... # type: ignore[bad-return-type,unused-ignore]
594594

595595
@staticmethod
596596
def _linear_combination_sorted_pairs(
@@ -862,7 +862,7 @@ def __gt__(self, other: DimSize):
862862
def __lt__(self, other: DimSize):
863863
return not _geq_decision(self, other, lambda: f"'{self}' < '{other}'")
864864

865-
def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]:
865+
def _divmod(self, divisor: DimSize) -> tuple[DimSize, DimSize]:
866866
"""
867867
Floor division with remainder (divmod) generalized to expressions.
868868
If the `divisor` is not a constant, the remainder must be 0.
@@ -1626,7 +1626,7 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
16261626
tok = self.next_tok()
16271627
elif tok.exact_type == tokenize.PLUS:
16281628
tok = self.next_tok()
1629-
acc = None
1629+
acc: DimSize | None = None
16301630
while True:
16311631
t, tok = self.term(tok)
16321632
t_sign = - t if next_t_negated else t
@@ -1640,7 +1640,7 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
16401640
FOLLOW_TERM = FOLLOW_EXPR + [tokenize.PLUS, tokenize.MINUS]
16411641
def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
16421642
# A term is product of factors. Each factor may be raised to an integer power.
1643-
acc = None
1643+
acc: DimSize | None = None
16441644
while True:
16451645
f, tok = self.factor(tok)
16461646
if tok.exact_type == tokenize.CIRCUMFLEX:

jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyrefly: ignore-errors
1516
# ruff: noqa
1617

1718
import datetime

jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_solve_lapack_gtsv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyrefly: ignore-errors
1516
# ruff: noqa
1617

1718
import datetime

jax/_src/internal_test_util/export_back_compat_test_data/cuda_cholesky_solver_potrf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyrefly: ignore-errors
1516
# ruff: noqa
1617

1718
import datetime

jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyrefly: ignore-errors
1516
# ruff: noqa
1617

1718
import datetime

jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyrefly: ignore-errors
1516
# ruff: noqa
1617

1718
import datetime

jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyrefly: ignore-errors
1516
# ruff: noqa
1617

1718
import datetime

0 commit comments

Comments
 (0)