Skip to content

Commit da24bf5

Browse files
committed
fix: ensure explicit int support and fix missing complex types
1 parent fc36a42 commit da24bf5

File tree

2 files changed

+28
-30
lines changed

2 files changed

+28
-30
lines changed

src/array_api_stubs/_draft/array_object.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ def __pow__(self: array, other: Union[int, float, complex, array], /) -> array:
10371037
----------
10381038
self: array
10391039
array instance whose elements correspond to the exponentiation base. Should have a numeric data type.
1040-
other: Union[int, float, array]
1040+
other: Union[int, float, complex, array]
10411041
other array whose elements correspond to the exponentiation exponent. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
10421042
10431043
Returns
@@ -1083,7 +1083,7 @@ def __setitem__(
10831083
key: Union[
10841084
int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array
10851085
],
1086-
value: Union[int, float, bool, array],
1086+
value: Union[int, float, complex, bool, array],
10871087
/,
10881088
) -> None:
10891089
"""
@@ -1097,17 +1097,15 @@ def __setitem__(
10971097
array instance.
10981098
key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array]
10991099
index key.
1100-
value: Union[int, float, bool, array]
1100+
value: Union[int, float, complex, bool, array]
11011101
value(s) to set. Must be compatible with ``self[key]`` (see :ref:`broadcasting`).
11021102
1103+
Notes
1104+
-----
11031105
1104-
.. note::
1105-
1106-
Setting array values must not affect the data type of ``self``.
1107-
1108-
When ``value`` is a Python scalar (i.e., ``int``, ``float``, ``bool``), behavior must follow specification guidance on mixing arrays with Python scalars (see :ref:`type-promotion`).
1109-
1110-
When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined.
1106+
- Setting array values must not affect the data type of ``self``.
1107+
- When ``value`` is a Python scalar (i.e., ``int``, ``float``, ``complex``, ``bool``), behavior must follow specification guidance on mixing arrays with Python scalars (see :ref:`type-promotion`).
1108+
- When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined.
11111109
"""
11121110

11131111
def __sub__(self: array, other: Union[int, float, complex, array], /) -> array:
@@ -1118,7 +1116,7 @@ def __sub__(self: array, other: Union[int, float, complex, array], /) -> array:
11181116
----------
11191117
self: array
11201118
array instance (minuend array). Should have a numeric data type.
1121-
other: Union[int, float, array]
1119+
other: Union[int, float, complex, array]
11221120
subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
11231121
11241122
Returns
@@ -1136,15 +1134,15 @@ def __sub__(self: array, other: Union[int, float, complex, array], /) -> array:
11361134
Added complex data type support.
11371135
"""
11381136

1139-
def __truediv__(self: array, other: Union[int, float, array], /) -> array:
1137+
def __truediv__(self: array, other: Union[int, float, complex, array], /) -> array:
11401138
r"""
11411139
Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``.
11421140
11431141
Parameters
11441142
----------
11451143
self: array
11461144
array instance. Should have a numeric data type.
1147-
other: Union[int, float, array]
1145+
other: Union[int, float, complex, array]
11481146
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
11491147
11501148
Returns

src/array_api_stubs/_draft/elementwise_functions.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def atan(x: array, /) -> array:
518518
"""
519519

520520

521-
def atan2(x1: Union[array, float], x2: Union[array, float], /) -> array:
521+
def atan2(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array:
522522
"""
523523
Calculates an implementation-dependent approximation of the inverse tangent of the quotient ``x1/x2``, having domain ``[-infinity, +infinity] x [-infinity, +infinity]`` (where the ``x`` notation denotes the set of ordered pairs of elements ``(x1_i, x2_i)``) and codomain ``[-π, +π]``, for each pair of elements ``(x1_i, x2_i)`` of the input arrays ``x1`` and ``x2``, respectively. Each element-wise result is expressed in radians.
524524
@@ -531,9 +531,9 @@ def atan2(x1: Union[array, float], x2: Union[array, float], /) -> array:
531531
532532
Parameters
533533
----------
534-
x1: Union[array, float]
534+
x1: Union[array, int, float]
535535
input array corresponding to the y-coordinates. Should have a real-valued floating-point data type.
536-
x2: Union[array, float]
536+
x2: Union[array, int, float]
537537
input array corresponding to the x-coordinates. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type.
538538
539539
Returns
@@ -824,9 +824,9 @@ def clip(
824824
x: array
825825
input array. Should have a real-valued data type.
826826
min: Optional[Union[int, float, array]]
827-
lower-bound of the range to which to clamp. If ``None``, no lower bound must be applied. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``.
827+
lower-bound of the range to which to clamp. If ``None``, no lower bound must be applied. Must be compatible with ``x`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``.
828828
max: Optional[Union[int, float, array]]
829-
upper-bound of the range to which to clamp. If ``None``, no upper bound must be applied. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``.
829+
upper-bound of the range to which to clamp. If ``None``, no upper bound must be applied. Must be compatible with ``x`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``.
830830
831831
Returns
832832
-------
@@ -883,15 +883,15 @@ def conj(x: array, /) -> array:
883883
"""
884884

885885

886-
def copysign(x1: Union[array, float], x2: Union[array, float], /) -> array:
886+
def copysign(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array:
887887
r"""
888888
Composes a floating-point value with the magnitude of ``x1_i`` and the sign of ``x2_i`` for each element of the input array ``x1``.
889889
890890
Parameters
891891
----------
892-
x1: Union[array, float]
892+
x1: Union[array, int, float]
893893
input array containing magnitudes. Should have a real-valued floating-point data type.
894-
x2: Union[array, float]
894+
x2: Union[array, int, float]
895895
input array whose sign bits are applied to the magnitudes of ``x1``. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type.
896896
897897
Returns
@@ -1436,7 +1436,7 @@ def greater_equal(
14361436
"""
14371437

14381438

1439-
def hypot(x1: Union[array, float], x2: Union[array, float], /) -> array:
1439+
def hypot(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array:
14401440
r"""
14411441
Computes the square root of the sum of squares for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``.
14421442
@@ -1445,9 +1445,9 @@ def hypot(x1: Union[array, float], x2: Union[array, float], /) -> array:
14451445
14461446
Parameters
14471447
----------
1448-
x1: Union[array, float]
1448+
x1: Union[array, int, float]
14491449
first input array. Should have a real-valued floating-point data type.
1450-
x2: Union[array, float]
1450+
x2: Union[array, int, float]
14511451
second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type.
14521452
14531453
Returns
@@ -1869,15 +1869,15 @@ def log10(x: array, /) -> array:
18691869
"""
18701870

18711871

1872-
def logaddexp(x1: Union[array, float], x2: Union[array, float], /) -> array:
1872+
def logaddexp(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array:
18731873
"""
18741874
Calculates the logarithm of the sum of exponentiations ``log(exp(x1) + exp(x2))`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``.
18751875
18761876
Parameters
18771877
----------
1878-
x1: Union[array, float]
1878+
x1: Union[array, int, float]
18791879
first input array. Should have a real-valued floating-point data type.
1880-
x2: Union[array, float]
1880+
x2: Union[array, int, float]
18811881
second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type.
18821882
18831883
Returns
@@ -2163,15 +2163,15 @@ def negative(x: array, /) -> array:
21632163
"""
21642164

21652165

2166-
def nextafter(x1: Union[array, float], x2: Union[array, float], /) -> array:
2166+
def nextafter(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array:
21672167
"""
21682168
Returns the next representable floating-point value for each element ``x1_i`` of the input array ``x1`` in the direction of the respective element ``x2_i`` of the input array ``x2``.
21692169
21702170
Parameters
21712171
----------
2172-
x1: Union[array, float]
2172+
x1: Union[array, int, float]
21732173
first input array. Should have a real-valued floating-point data type.
2174-
x2: Union[array, float]
2174+
x2: Union[array, int, float]
21752175
second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have the same data type as ``x1``.
21762176
21772177
Returns

0 commit comments

Comments
 (0)