Skip to content

Commit 71c6a41

Browse files
committed
WIP: try fixing ceil dtypes
1 parent 9815b8d commit 71c6a41

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

array_api_compat/common/_aliases.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,9 +528,13 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
528528

529529

530530
def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
531-
if xp.issubdtype(x.dtype, xp.integer):
532-
return x
533-
return xp.ceil(x, **kwargs)
531+
# if xp.issubdtype(x.dtype, xp.integer):
532+
# return x
533+
result = xp.ceil(x, **kwargs)
534+
if result.dtype != x.dtype:
535+
# numpy < 2: ceil(int array) is float
536+
result = xp.astype(result, x.dtype)
537+
return result
534538

535539

536540
def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:

0 commit comments

Comments
 (0)