Skip to content

Commit 61d94d4

Browse files
Implemented __complex__ magic method
1 parent 101bcc2 commit 61d94d4

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,15 @@ cdef class usm_ndarray:
515515
"only size-1 arrays can be converted to Python scalars"
516516
)
517517

518+
def __complex__(self):
519+
if self.size == 1:
520+
mem_view = dpmem.as_usm_memory(self)
521+
return mem_view.copy_to_host().view(self.dtype).__complex__()
522+
523+
raise ValueError(
524+
"only size-1 arrays can be converted to Python scalars"
525+
)
526+
518527
def __int__(self):
519528
if self.size == 1:
520529
mem_view = dpmem.as_usm_memory(self)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def test_properties():
114114
assert isinstance(X.ndim, numbers.Integral)
115115

116116

117-
@pytest.mark.parametrize("func", [bool, float, int])
118-
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
117+
@pytest.mark.parametrize("func", [bool, float, int, complex])
118+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
119119
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
120120
def test_copy_scalar_with_func(func, shape, dtype):
121121
X = dpt.usm_ndarray(shape, dtype=dtype)
@@ -124,8 +124,10 @@ def test_copy_scalar_with_func(func, shape, dtype):
124124
assert func(X) == func(Y)
125125

126126

127-
@pytest.mark.parametrize("method", ["__bool__", "__float__", "__int__"])
128-
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
127+
@pytest.mark.parametrize(
128+
"method", ["__bool__", "__float__", "__int__", "__complex__"]
129+
)
130+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
129131
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
130132
def test_copy_scalar_with_method(method, shape, dtype):
131133
X = dpt.usm_ndarray(shape, dtype=dtype)
@@ -134,7 +136,7 @@ def test_copy_scalar_with_method(method, shape, dtype):
134136
assert getattr(X, method)() == getattr(Y, method)()
135137

136138

137-
@pytest.mark.parametrize("func", [bool, float, int])
139+
@pytest.mark.parametrize("func", [bool, float, int, complex])
138140
@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)])
139141
def test_copy_scalar_invalid_shape(func, shape):
140142
X = dpt.usm_ndarray(shape)

0 commit comments

Comments
 (0)