Skip to content

Commit 1dc4955

Browse files
committed
Reinstate explicit bool | complex
1 parent 2713c7c commit 1dc4955

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,12 @@
6464

6565
_copy_default = object()
6666

67+
6768
# asarray also adds the copy keyword, which is not present in numpy 1.0.
6869
def asarray(
69-
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
70+
obj: (
71+
Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
72+
),
7073
/,
7174
*,
7275
dtype: Optional[DType] = None,

array_api_compat/dask/array/_aliases.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def arange(
135135

136136
# asarray also adds the copy keyword, which is not present in numpy 1.0.
137137
def asarray(
138-
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
138+
obj: (
139+
Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
140+
),
139141
/,
140142
*,
141143
dtype: Optional[DType] = None,

array_api_compat/numpy/_aliases.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ._typing import Array, Device, DType
1010

1111
import numpy as np
12+
1213
bool = np.bool_
1314

1415
# Basic renames
@@ -61,19 +62,23 @@
6162
tensordot = get_xp(np)(_aliases.tensordot)
6263
sign = get_xp(np)(_aliases.sign)
6364

65+
6466
def _supports_buffer_protocol(obj):
6567
try:
6668
memoryview(obj)
6769
except TypeError:
6870
return False
6971
return True
7072

73+
7174
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7275
# asarray() is different enough between numpy, cupy, and dask, the logic
7376
# complicated enough that it's easier to define it separately for each module
7477
# rather than trying to combine everything into one function in common/
7578
def asarray(
76-
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
79+
obj: (
80+
Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
81+
),
7782
/,
7883
*,
7984
dtype: Optional[DType] = None,
@@ -119,11 +124,7 @@ def astype(
119124

120125
# count_nonzero returns a python int for axis=None and keepdims=False
121126
# https://github.com/numpy/numpy/issues/17562
122-
def count_nonzero(
123-
x : Array,
124-
axis=None,
125-
keepdims=False
126-
) -> Array:
127+
def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
127128
result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
128129
if axis is None and not keepdims:
129130
return np.asarray(result)

0 commit comments

Comments
 (0)