Skip to content

Commit 74c509f

Browse files
committed
adress PR comments
1 parent 45121c5 commit 74c509f

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

src/array_api_extra/_delegation.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_torch_namespace,
1616
)
1717
from ._lib._utils._compat import device as get_device
18-
from ._lib._utils._helpers import asarrays
18+
from ._lib._utils._helpers import asarrays, eager_shape
1919
from ._lib._utils._typing import Array, DType
2020

2121
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
@@ -364,10 +364,7 @@ def partition(
364364
raise TypeError(msg)
365365
if axis is None:
366366
return partition(xp.reshape(a, -1), kth, axis=0, xp=xp)
367-
size = a.shape[axis]
368-
if size is None:
369-
msg = "Array dimensions must be known"
370-
raise ValueError(msg)
367+
(size,) = eager_shape(a, axis)
371368
if not (0 <= kth < size):
372369
msg = f"kth(={kth}) out of bounds [0 {size})"
373370
raise ValueError(msg)
@@ -445,10 +442,7 @@ def argpartition(
445442
raise TypeError(msg)
446443
if axis is None:
447444
return partition(xp.reshape(a, -1), kth, axis=0, xp=xp)
448-
size = a.shape[axis]
449-
if size is None:
450-
msg = "Array dimensions must be known"
451-
raise ValueError(msg)
445+
(size,) = eager_shape(a, axis)
452446
if not (0 <= kth < size):
453447
msg = f"kth(={kth}) out of bounds [0 {size})"
454448
raise ValueError(msg)

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,22 +250,31 @@ def ndindex(*x: int) -> Generator[tuple[int, ...]]:
250250
yield *i, j
251251

252252

253-
def eager_shape(x: Array, /) -> tuple[int, ...]:
253+
def eager_shape(x: Array, /, axis: int | None = None) -> tuple[int, ...]:
254254
"""
255255
Return shape of an array. Raise if shape is not fully defined.
256256
257257
Parameters
258258
----------
259259
x : Array
260260
Input array.
261+
axis : int, optional
262+
If provided, only returns the tuple (shape[axis],).
261263
262264
Returns
263265
-------
264266
tuple[int, ...]
265267
Shape of the array.
266268
"""
267269
shape = x.shape
268-
# Dask arrays uses non-standard NaN instead of None
270+
if axis is not None:
271+
s = shape[axis]
272+
# Dask arrays uses non-standard NaN instead of None
273+
if s is None or math.isnan(s):
274+
msg = f"Unsupported lazy shape for axis {axis}"
275+
raise TypeError(msg)
276+
return (s,)
277+
269278
if any(s is None or math.isnan(s) for s in shape):
270279
msg = "Unsupported lazy shape"
271280
raise TypeError(msg)

tests/test_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,14 @@ def test_eager_shape(xp: ModuleType, library: Backend):
182182
# Lazy arrays, like Dask, have an eager shape until you slice them with
183183
# a lazy boolean mask
184184
assert eager_shape(a) == a.shape == (3,)
185+
assert eager_shape(a, axis=0) == a.shape == (3,)
185186

186187
b = a[a > 2]
187188
if library is Backend.DASK:
188189
with pytest.raises(TypeError, match="Unsupported lazy shape"):
189190
_ = eager_shape(b)
191+
with pytest.raises(TypeError, match="Unsupported lazy shape"):
192+
_ = eager_shape(b, axis=0)
190193
# FIXME can't test use case for None in the shape until we add support for
191194
# other lazy backends
192195
else:

0 commit comments

Comments
 (0)