Skip to content

Commit 6524ad8

Browse files
committed
fix mypy errors
1 parent 8bd5927 commit 6524ad8

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/array_api_extra/_funcs.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,17 @@
55

66
import operator
77
import warnings
8-
9-
# https://github.com/pylint-dev/pylint/issues/10112
10-
from collections.abc import Callable # pylint: disable=import-error
11-
from typing import ClassVar, Literal
8+
from collections.abc import Callable
9+
from types import ModuleType
10+
from typing import ClassVar, Literal, cast
1211

1312
from ._lib import _compat, _utils
1413
from ._lib._compat import (
1514
array_namespace,
1615
is_jax_array,
1716
is_writeable_array,
1817
)
19-
from ._lib._typing import Array, Index, ModuleType
18+
from ._lib._typing import Array, Index
2019

2120
__all__ = [
2221
"at",
@@ -779,7 +778,7 @@ def _update_common(
779778
if copy:
780779
if is_jax_array(x):
781780
# Use JAX's at[]
782-
func = getattr(x.at[idx], at_op)
781+
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op))
783782
return func(y), None
784783
# Emulate at[] behaviour for non-JAX arrays
785784
# with a copy followed by an update

tests/test_at.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from collections.abc import Generator
1+
from collections.abc import Callable, Generator
22
from contextlib import contextmanager
33
from importlib import import_module
4+
from typing import cast
45

56
import numpy as np
67
import pytest
@@ -100,7 +101,8 @@ def test_update_ops(
100101
pytest.skip("at() does not support updates on sparse arrays")
101102

102103
with assert_copy(array, expect_copy):
103-
y = getattr(at(array)[1:], op)(arg, **kwargs)
104+
func = cast(Callable[..., Array], getattr(at(array)[1:], op)) # type: ignore[no-any-explicit]
105+
y = func(arg, **kwargs)
104106
assert isinstance(y, type(array))
105107
assert_array_equal(y, expect)
106108

@@ -153,6 +155,6 @@ def test_iops_incompatible_dtype(op: str, copy: bool):
153155
to dtype('int64') with casting rule 'same_kind'
154156
"""
155157
a = np.asarray([2, 4])
156-
func = getattr(at(a)[:], op)
158+
func = cast(Callable[..., Array], getattr(at(a)[:], op)) # type: ignore[no-any-explicit]
157159
with pytest.raises(TypeError, match="Cannot cast ufunc"):
158160
func(1.1, copy=copy)

0 commit comments

Comments
 (0)