Skip to content

Commit dd77152

Browse files
committed
Merge pull request #21 from FilipeMaia/array_fixes
Bugfixes to ArrayFire
2 parents 58e030a + a87481a commit dd77152

File tree

3 files changed

+37
-14
lines changed

3 files changed

+37
-14
lines changed

arrayfire/arith.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,24 @@ def arith_binary_func(lhs, rhs, c_func):
1818
is_right_array = isinstance(rhs, array)
1919

2020
if not (is_left_array or is_right_array):
21-
TypeError("Atleast one input needs to be of type arrayfire.array")
21+
raise TypeError("Atleast one input needs to be of type arrayfire.array")
2222

2323
elif (is_left_array and is_right_array):
2424
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, bcast.get()))
2525

2626
elif (is_number(rhs)):
2727
ldims = dim4_tuple(lhs.dims())
28-
lty = lhs.type()
28+
rty = number_dtype(rhs)
2929
other = array()
30-
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
30+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
3131
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
3232

3333
else:
3434
rdims = dim4_tuple(rhs.dims())
35-
rty = rhs.type()
35+
lty = number_dtype(lhs)
3636
other = array()
37-
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
38-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
37+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
38+
safe_call(c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get()))
3939

4040
return out
4141

arrayfire/array.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def binary_func(lhs, rhs, c_func):
6060

6161
if (is_number(rhs)):
6262
ldims = dim4_tuple(lhs.dims())
63-
lty = lhs.type()
63+
rty = number_dtype(rhs)
6464
other = array()
65-
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
65+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
6666
elif not isinstance(rhs, array):
6767
raise TypeError("Invalid parameter to binary function")
6868

@@ -76,9 +76,9 @@ def binary_funcr(lhs, rhs, c_func):
7676

7777
if (is_number(lhs)):
7878
rdims = dim4_tuple(rhs.dims())
79-
rty = rhs.type()
79+
lty = number_dtype(lhs)
8080
other = array()
81-
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
81+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
8282
elif not isinstance(lhs, array):
8383
raise TypeError("Invalid parameter to binary function")
8484

@@ -179,13 +179,18 @@ def __init__(self, src=None, dims=(0,), type_char=None):
179179

180180
def copy(self):
181181
out = array()
182-
safe_call(clib.af_retain_array(ct.pointer(out.arr), self.arr))
182+
safe_call(clib.af_copy_array(ct.pointer(out.arr), self.arr))
183183
return out
184184

185185
def __del__(self):
186186
if (self.arr.value != 0):
187187
clib.af_release_array(self.arr)
188188

189+
def device_ptr(self):
190+
ptr = ctypes.c_void_p(0)
191+
clib.af_get_device_ptr(ct.pointer(ptr), self.arr)
192+
return ptr.value
193+
189194
def elements(self):
190195
num = ct.c_ulonglong(0)
191196
safe_call(clib.af_get_elements(ct.pointer(num), self.arr))

arrayfire/util.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ def dim4(d0=1, d1=1, d2=1, d3=1):
2222
def is_number(a):
2323
return isinstance(a, numbers.Number)
2424

25+
def number_dtype(a):
26+
if isinstance(a, bool):
27+
return b8
28+
if isinstance(a, int):
29+
return s64
30+
elif isinstance(a, float):
31+
return f64
32+
elif isinstance(a, complex):
33+
return c64
34+
else:
35+
return to_dtype[a.dtype.char]
36+
2537
def dim4_tuple(dims, default=1):
2638
assert(isinstance(dims, tuple))
2739

@@ -59,7 +71,9 @@ def get_version():
5971
'i' : s32,
6072
'I' : u32,
6173
'l' : s64,
62-
'L' : u64}
74+
'L' : u64,
75+
'F' : c32,
76+
'D' : c64}
6377

6478
to_typecode = {f32.value : 'f',
6579
f64.value : 'd',
@@ -68,7 +82,9 @@ def get_version():
6882
s32.value : 'i',
6983
u32.value : 'I',
7084
s64.value : 'l',
71-
u64.value : 'L'}
85+
u64.value : 'L',
86+
c32.value : 'F',
87+
c64.value : 'D'}
7288

7389
to_c_type = {f32.value : ct.c_float,
7490
f64.value : ct.c_double,
@@ -77,4 +93,6 @@ def get_version():
7793
s32.value : ct.c_int,
7894
u32.value : ct.c_uint,
7995
s64.value : ct.c_longlong,
80-
u64.value : ct.c_ulonglong}
96+
u64.value : ct.c_ulonglong,
97+
c32.value : ct.c_float * 2,
98+
c64.value : ct.c_double * 2}

0 commit comments

Comments
 (0)