Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/array_api_typing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Static typing support for the array API standard."""

__all__ = (
__all__ = ( # noqa: RUF022
# ==================
# Array
"Array",
"HasArrayNamespace",
"HasDType",
Expand All @@ -9,6 +11,12 @@
"HasShape",
"HasSize",
"HasTranspose",
# ==================
# Namespace
"ArrayNamespace",
"DoesAsType",
"HasAsType",
# ==================
"__version__",
"__version_tuple__",
)
Expand All @@ -23,4 +31,5 @@
HasSize,
HasTranspose,
)
from ._namespace import ArrayNamespace, DoesAsType, HasAsType
from ._version import version as __version__, version_tuple as __version_tuple__
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These empty files snuck in. I'll remove when rebasing after #34 is in.

Empty file.
Empty file.
112 changes: 112 additions & 0 deletions src/array_api_typing/_namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import Protocol, TypeVar

from ._array import HasDType

__all__ = (
"ArrayNamespace",
# Data Type Functions
"DoesAsType",
"HasAsType",
)

DTypeT = TypeVar("DTypeT")
ToDTypeT = TypeVar("ToDTypeT")

# ===================================================================
# Creation Functions
# TODO: arange, asarray, empty, empty_like, eye, from_dlpack, full, full_like,
# linspace, meshgrid, ones, ones_like, tril, triu, zeros, zeros_like

# ===================================================================
# Data Type Functions
# TODO: broadcast_arrays, broadcast_to, can_cast, finfo, iinfo,
# result_type


class DoesAsType(Protocol):
"""Copies an array to a specified data type irrespective of Type Promotion Rules rules.

Note:
Casting floating-point ``NaN`` and ``infinity`` values to integral data
types is not specified and is implementation-dependent.

Note:
When casting a boolean input array to a numeric data type, a value of
`True` must cast to a numeric value equal to ``1``, and a value of
`False` must cast to a numeric value equal to ``0``.

When casting a numeric input array to bool, a value of ``0`` must cast
to `False`, and a non-zero value must cast to `True`.

Args:
x: The array to cast.
dtype: desired data type.
copy: specifies whether to copy an array when the specified `dtype`
matches the data type of the input array `x`. If `True`, a newly
allocated array must always be returned. If `False` and the
specified `dtype` matches the data type of the input array, the
input array must be returned; otherwise, a newly allocated must be
returned. Default: `True`.

""" # noqa: E501

def __call__(
self, x: HasDType[DTypeT], dtype: ToDTypeT, /, *, copy: bool = True
) -> HasDType[ToDTypeT]: ...


class HasAsType(Protocol):
"""Protocol for namespaces that have an ``astype`` function."""

astype: DoesAsType


# ===================================================================
# Element-wise Functions
# TODO: abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and,
# bitwise_invert, bitwise_left_shift, bitwise_or, bitwise_right_shift,
# bitwise_xor, ceil, cos, cosh, divide, equal, exp, exp2, expm1, floor,
# floor_divide, greater, greater_equal, isfinite, isinf, isnan, less,
# less_equal, log, log1p, log2, log10, logical_and, logical_not, logical_or,
# logical_xor, multiply, negative, not_equal, positive, pow, remainder, round,
# sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc


# ===================================================================
# Linear Algebra Functions
# TODO: matmul, matrix_transpose, tensordot, vecdot

# ===================================================================
# Manipulation Functions
# TODO: concat, expand_dims, flip, permute_dims, reshape, roll, squeeze, stack

# ===================================================================
# Searching Functions
# TODO: argmax, argmin, nonzero, where

# ===================================================================
# Set Functions
# TODO: unique_all, unique_counts, unique_inverse, unique_values

# ===================================================================
# Sorting Functions
# TODO: argsort, sort

# ===================================================================
# Statistical Functions
# TODO: max, mean, min, prod, std, sum, var

# ===================================================================
# Utility Functions
# TODO: all, any

# ===================================================================
# Full Namespace


class ArrayNamespace(
# Data Type Functions
HasAsType,
Protocol,
):
"""Protocol for an Array API-compatible namespace."""
6 changes: 6 additions & 0 deletions tests/integration/test_numpy2p0.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,9 @@ assert_type(x_b.size, int | None)
assert_type(x_f32.T, xpt.Array[np.dtype[F32]])
assert_type(x_i32.T, xpt.Array[np.dtype[I32]])
assert_type(x_b.T, xpt.Array[np.dtype[B]])

##############################################################################
# Tests on Namespace Functions

_: xpt.DoesAsType = np.astype
_: xpt.HasAsType = np
Loading