Skip to content

Commit 9c0435a

Browse files
committed
Fix bug when scalar is empty returns None
1 parent cdb7a92 commit 9c0435a

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,13 @@ def shape(self) -> ShapeType:
502502
ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), self.arr))
503503
return (d0.value, d1.value, d2.value, d3.value)[:self.ndim] # Skip passing None values
504504

505-
def scalar(self) -> int | float | bool | complex:
505+
def scalar(self) -> None | int | float | bool | complex:
506506
"""
507507
Return the first element of the array
508508
"""
509-
# BUG seg fault on empty array
509+
if self.is_empty():
510+
return None
511+
510512
out = self.dtype.c_type()
511513
safe_call(backend.get().af_get_scalar(ctypes.pointer(out), self.arr))
512514
return out.value # type: ignore[no-any-return] # FIXME

arrayfire/array_api/tests/test_array_object.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,24 @@ def test_array_getitem() -> None:
107107
# TODO add more tests for different dtypes
108108

109109

110+
def test_scalar() -> None:
111+
array = Array([1, 2, 3])
112+
assert array[1].scalar() == 2
113+
114+
115+
def test_scalar_is_empty() -> None:
116+
array = Array()
117+
assert array.scalar() is None
118+
119+
110120
def test_array_to_list() -> None:
111-
# TODO add test of to_ctypes_array
112-
assert Array([1, 2, 3]).to_list() == [1, 2, 3]
113-
assert Array().to_list() == []
121+
array = Array([1, 2, 3])
122+
assert array.to_list() == [1, 2, 3]
123+
124+
125+
def test_array_to_list_is_empty() -> None:
126+
array = Array()
127+
assert array.to_list() == []
114128

115129

116130
def test_array_add() -> None:

0 commit comments

Comments
 (0)