Skip to content

Commit 22d4fc0

Browse files
authored
Merge branch 'main' into where-with-scalars
2 parents 9feef07 + 4f63331 commit 22d4fc0

19 files changed

+414
-545
lines changed

.github/workflows/docs-deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
steps:
1414
- uses: actions/checkout@v4
1515
- name: Download Artifact
16-
uses: dawidd6/action-download-artifact@v6
16+
uses: dawidd6/action-download-artifact@v7
1717
with:
1818
workflow: docs-build.yml
1919
name: docs-build

.github/workflows/publish-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,14 @@ jobs:
9797
# if: >-
9898
# (github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
9999
# || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
100-
# uses: pypa/[email protected].2
100+
# uses: pypa/[email protected].3
101101
# with:
102102
# repository-url: https://test.pypi.org/legacy/
103103
# print-hash: true
104104

105105
- name: Publish distribution 📦 to PyPI
106106
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
107-
uses: pypa/[email protected].2
107+
uses: pypa/[email protected].3
108108
with:
109109
print-hash: true
110110

.gitignore

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ ENV/
128128
env.bak/
129129
venv.bak/
130130

131-
# Spyder project settings
131+
# Project settings
132+
.idea
133+
.ropeproject
132134
.spyderproject
133135
.spyproject
134-
135-
# Rope project settings
136-
.ropeproject
136+
.vscode
137137

138138
# mkdocs documentation
139139
/site

array_api_strict/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,9 @@
293293

294294
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]
295295

296-
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
296+
from ._searching_functions import argmax, argmin, nonzero, count_nonzero, searchsorted, where
297297

298-
__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]
298+
__all__ += ["argmax", "argmin", "nonzero", "count_nonzero", "searchsorted", "where"]
299299

300300
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
301301

@@ -305,9 +305,9 @@
305305

306306
__all__ += ["argsort", "sort"]
307307

308-
from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
308+
from ._statistical_functions import cumulative_sum, cumulative_prod, max, mean, min, prod, std, sum, var
309309

310-
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
310+
__all__ += ["cumulative_sum", "cumulative_prod", "max", "mean", "min", "prod", "std", "sum", "var"]
311311

312312
from ._utility_functions import all, any, diff
313313

array_api_strict/_array_object.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,6 @@ def __new__(cls, *args, **kwargs):
126126
# These functions are not required by the spec, but are implemented for
127127
# the sake of usability.
128128

129-
def __str__(self: Array, /) -> str:
130-
"""
131-
Performs the operation __str__.
132-
"""
133-
return self._array.__str__().replace("array", "Array")
134-
135129
def __repr__(self: Array, /) -> str:
136130
"""
137131
Performs the operation __repr__.
@@ -149,6 +143,8 @@ def __repr__(self: Array, /) -> str:
149143
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
150144
return prefix + mid + suffix
151145

146+
__str__ = __repr__
147+
152148
# In the future, _allow_array will be set to False, which will disallow
153149
# __array__. This means calling `np.func()` on an array_api_strict array
154150
# will give an error. If we don't explicitly disallow it, NumPy defaults

array_api_strict/_creation_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def from_dlpack(
226226
# Going to wait for upstream numpy support
227227
if device is not _default:
228228
_check_device(device)
229+
else:
230+
device = None
229231
if copy not in [_default, None]:
230232
raise NotImplementedError("The copy argument to from_dlpack is not yet implemented")
231233

array_api_strict/_data_type_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def astype(
4242

4343
if not copy and dtype == x.dtype:
4444
return x
45+
46+
if isdtype(x.dtype, 'complex floating') and not isdtype(dtype, 'complex floating'):
47+
raise TypeError(
48+
f'The Array API standard stipulates that casting {x.dtype} to {dtype} should not be permitted. '
49+
'array-api-strict thus prohibits this conversion.'
50+
)
51+
4552
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device)
4653

4754

@@ -160,6 +167,9 @@ def isdtype(
160167
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
161168
for more details
162169
"""
170+
if not isinstance(dtype, _DType):
171+
raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}")
172+
163173
if isinstance(kind, tuple):
164174
# Disallow nested tuples
165175
if any(isinstance(k, tuple) for k in kind):

0 commit comments

Comments
 (0)