Skip to content

Commit b0c88f4

Browse files
committed
TYP: auto-plagiarize the optypean Just* types
1 parent 205c967 commit b0c88f4

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

array_api_compat/common/_typing.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22

33
from collections.abc import Mapping
44
from types import ModuleType as Namespace
5-
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
5+
from typing import (
6+
TYPE_CHECKING,
7+
Literal,
8+
Protocol,
9+
TypeAlias,
10+
TypedDict,
11+
TypeVar,
12+
final,
13+
)
614

715
if TYPE_CHECKING:
816
from _typeshed import Incomplete
@@ -21,6 +29,37 @@
2129
_T_co = TypeVar("_T_co", covariant=True)
2230

2331

32+
# These "Just" types are equivalent to the `Just` type from the `optype` library,
33+
# apart from them not being `@runtime_checkable`.
34+
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
35+
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
36+
@final
37+
class JustInt(Protocol):
38+
@property
39+
def __class__(self, /) -> type[int]: ...
40+
@__class__.setter
41+
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
42+
43+
44+
@final
45+
class JustFloat(Protocol):
46+
@property
47+
def __class__(self, /) -> type[float]: ...
48+
@__class__.setter
49+
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
50+
51+
52+
@final
53+
class JustComplex(Protocol):
54+
@property
55+
def __class__(self, /) -> type[complex]: ...
56+
@__class__.setter
57+
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
58+
59+
60+
#
61+
62+
2463
class NestedSequence(Protocol[_T_co]):
2564
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
2665
def __len__(self, /) -> int: ...
@@ -121,6 +160,8 @@ class DTypesAll(DTypesBool, DTypesNumeric):
121160
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
122161
DTypesAny: TypeAlias = Mapping[str, DType]
123162

163+
NormOrder: TypeAlias = JustFloat | Literal[-2, -1, 1, 2]
164+
124165

125166
__all__ = [
126167
"Array",
@@ -140,6 +181,9 @@ class DTypesAll(DTypesBool, DTypesNumeric):
140181
"Device",
141182
"HasShape",
142183
"Namespace",
184+
"JustInt",
185+
"JustFloat",
186+
"JustComplex",
143187
"NestedSequence",
144188
"SupportsArrayNamespace",
145189
"SupportsBufferProtocol",

0 commit comments

Comments
 (0)