Skip to content

Commit 71c38ca

Browse files
committed
FEAT/TEST: Adding support for getting data back to the host
1 parent 13aebff commit 71c38ca

File tree

6 files changed

+85
-13
lines changed

6 files changed

+85
-13
lines changed

arrayfire/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,6 @@
5353
del get_indices
5454
del get_assign_dims
5555
del slice_to_length
56+
del ctype_to_lists
57+
del to_dtype
58+
del to_c_type

arrayfire/array.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def binary_funcr(lhs, rhs, c_func):
8383

8484
return out
8585

86+
def transpose(a, conj=False):
87+
out = array()
88+
safe_call(clib.af_transpose(ct.pointer(out.arr), a.arr, conj))
89+
return out
90+
91+
def transpose_inplace(a, conj=False):
92+
safe_call(clib.af_transpose_inplace(a.arr, conj))
93+
8694
class seq(ct.Structure):
8795
_fields_ = [("begin", ct.c_double),
8896
("end" , ct.c_double),
@@ -163,6 +171,17 @@ def slice_to_length(key, dim):
163171

164172
return int(((tkey[1] - tkey[0] - 1) / tkey[2]) + 1)
165173

174+
def ctype_to_lists(ctype_arr, dim, shape, offset=0):
175+
if (dim == 0):
176+
return list(ctype_arr[offset : offset + shape[0]])
177+
else:
178+
dim_len = shape[dim]
179+
res = [[]] * dim_len
180+
for n in range(dim_len):
181+
res[n] = ctype_to_lists(ctype_arr, dim - 1, shape, offset)
182+
offset += shape[0]
183+
return res
184+
166185
def get_assign_dims(key, idims):
167186
dims = [1]*4
168187

@@ -518,6 +537,31 @@ def __setitem__(self, key, val):
518537
except RuntimeError as e:
519538
raise IndexError(str(e))
520539

540+
def to_ctype(self, row_major=False, return_shape=False):
541+
tmp = transpose(self) if row_major else self
542+
ctype_type = to_c_type[self.type()] * self.elements()
543+
res = ctype_type()
544+
safe_call(clib.af_get_data_ptr(ct.pointer(res), self.arr))
545+
if (return_shape):
546+
return res, self.dims()
547+
else:
548+
return res
549+
550+
def to_array(self, row_major=False, return_shape=False):
551+
res = self.to_ctype(row_major, return_shape)
552+
553+
host = __import__("array")
554+
h_type = to_typecode[self.type()]
555+
556+
if (return_shape):
557+
return host.array(h_type, res[0]), res[1]
558+
else:
559+
return host.array(h_type, res)
560+
561+
def to_list(self, row_major=False):
562+
ct_array, shape = self.to_ctype(row_major, True)
563+
return ctype_to_lists(ct_array, len(shape) - 1, shape)
564+
521565
def display(a):
522566
expr = inspect.stack()[1][-2]
523567
if (expr is not None):

arrayfire/blas.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,3 @@ def dot(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
3939
safe_call(clib.af_dot(ct.pointer(out.arr), lhs.arr, rhs.arr,\
4040
lhs_opts, rhs_opts))
4141
return out
42-
43-
def transpose(a, conj=False):
44-
out = array()
45-
safe_call(clib.af_transpose(ct.pointer(out.arr), a.arr, conj))
46-
return out
47-
48-
def transpose_inplace(a, conj=False):
49-
safe_call(clib.af_transpose_inplace(a.arr, conj))

arrayfire/util.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,21 @@ def get_version():
6060
'I' : u32,
6161
'l' : s64,
6262
'L' : u64}
63+
64+
to_typecode = {f32.value : 'f',
65+
f64.value : 'd',
66+
b8.value : 'b',
67+
u8.value : 'B',
68+
s32.value : 'i',
69+
u32.value : 'I',
70+
s64.value : 'l',
71+
u64.value : 'L'}
72+
73+
to_c_type = {f32.value : ct.c_float,
74+
f64.value : ct.c_double,
75+
b8.value : ct.c_char,
76+
u8.value : ct.c_ubyte,
77+
s32.value : ct.c_int,
78+
u32.value : ct.c_uint,
79+
s64.value : ct.c_longlong,
80+
u64.value : ct.c_ulonglong}

tests/simple_array.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,23 @@
6666
af.display(a)
6767
a[idx, idx] = af.randu(3,3)
6868
af.display(a)
69+
70+
af.display(af.transpose(a))
71+
72+
af.transpose_inplace(a)
73+
af.display(a)
74+
75+
c = a.to_ctype()
76+
for n in range(a.elements()):
77+
print(c[n])
78+
79+
c,s = a.to_ctype(True, True)
80+
for n in range(a.elements()):
81+
print(c[n])
82+
print(s)
83+
84+
arr = a.to_array()
85+
lst = a.to_list(True)
86+
87+
print(arr)
88+
print(lst)

tests/simple_blas.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,3 @@
1919

2020
b = af.randu(5,1)
2121
af.display(af.dot(b,b))
22-
23-
af.display(af.transpose(a))
24-
25-
af.transpose_inplace(a)
26-
af.display(a)

0 commit comments

Comments
 (0)