Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
comment: false
ignore:
- "src/array_api_extra/_typing"
503 changes: 221 additions & 282 deletions pixi.lock

Large diffs are not rendered by default.

24 changes: 22 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,23 @@ array-api-extra = { path = ".", editable = true }

[tool.pixi.feature.lint.dependencies]
pre-commit = "*"
mypy = "*"
pylint = "*"
# import dependencies for mypy:
array-api-strict = "*"
numpy = "*"
pytest = "*"

[tool.pixi.feature.lint.pypi-dependencies]
basedmypy = "*"
basedpyright = "*"

[tool.pixi.feature.lint.tasks]
pre-commit-install = { cmd = "pre-commit install" }
pre-commit = { cmd = "pre-commit run -v --all-files --show-diff-on-failure" }
mypy = { cmd = "mypy", cwd = "." }
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
lint = { depends-on = ["pre-commit", "pylint", "mypy"] }
pyright = { cmd = "basedpyright", cwd = "." }
lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] }

[tool.pixi.feature.tests.dependencies]
pytest = ">=6"
Expand Down Expand Up @@ -165,13 +169,29 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
# array-api#589
disallow_any_expr = false

[[tool.mypy.overrides]]
module = "array_api_extra.*"
disallow_untyped_defs = true
disallow_incomplete_defs = true


# pyright

[tool.pyright]
include = ["src", "tests"]
pythonVersion = "3.10"
pythonPlatform = "All"
typeCheckingMode = "strict"

# array-api#589
reportAny = false
reportExplicitAny = false
reportUnknownMemberType = false


# Ruff

[tool.ruff]
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import typing
import warnings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]
Expand Down
3 changes: 2 additions & 1 deletion src/array_api_extra/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from types import ModuleType
from typing import Any

Array = Any # To be changed to a Protocol later (see array-api#589)
# To be changed to a Protocol later (see array-api#589)
Array = Any # type: ignore[no-any-explicit]

__all__ = ["Array", "ModuleType"]
15 changes: 8 additions & 7 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)
# To be changed to a Protocol later (see array-api#589)
Array = Any # type: ignore[no-any-explicit]


class TestAtLeastND:
Expand Down Expand Up @@ -131,7 +132,7 @@ def test_1d(self):

@pytest.mark.parametrize("n", range(1, 10))
@pytest.mark.parametrize("offset", range(1, 10))
def test_create_diagonal(self, n, offset):
def test_create_diagonal(self, n: int, offset: int):
# from scipy._lib tests
rng = np.random.default_rng(2347823)
one = xp.asarray(1.0)
Expand Down Expand Up @@ -180,9 +181,9 @@ def test_basic(self):
assert_array_equal(kron(a, b, xp=xp), k)

def test_kron_smoke(self):
a = xp.ones([3, 3])
b = xp.ones([3, 3])
k = xp.ones([9, 9])
a = xp.ones((3, 3))
b = xp.ones((3, 3))
k = xp.ones((9, 9))

assert_array_equal(kron(a, b, xp=xp), k)

Expand All @@ -197,7 +198,7 @@ def test_kron_smoke(self):
((2, 0, 0, 2), (2, 0, 2)),
],
)
def test_kron_shape(self, shape_a, shape_b):
def test_kron_shape(self, shape_a: tuple[int], shape_b: tuple[int]):
a = xp.ones(shape_a)
b = xp.ones(shape_b)
normalised_shape_a = xp.asarray(
Expand Down Expand Up @@ -271,7 +272,7 @@ def test_simple(self):
assert_allclose(w, xp.flip(w, axis=0))

@pytest.mark.parametrize("x", [0, 1 + 3j])
def test_dtype(self, x):
def test_dtype(self, x: int | complex):
with pytest.raises(ValueError, match="real floating data type"):
sinc(xp.asarray(x), xp=xp)

Expand Down