Skip to content

Commit 31e023a

Browse files
committed
Add __array__ for easier numpy interoperability
1 parent 82c69c5 commit 31e023a

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

arrayfire/array.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,12 @@ def __repr__(self):
499499
safe_call(backend.get().af_print_array(self.arr))
500500
return '%s of dimensions %s' % (type(self), self.dims())
501501

502+
def __array__(self):
503+
import numpy as np
504+
res = np.empty(self.dims(), dtype=np.dtype(to_typecode[self.type()]), order='F')
505+
safe_call(backend.get().af_get_data_ptr(ct.c_void_p(res.ctypes.data), self.arr))
506+
return res
507+
502508
def display(a):
503509
expr = inspect.stack()[1][-2]
504510
if (expr is not None):

0 commit comments

Comments
 (0)