Skip to content

Commit b0429c7

Browse files
committed
Merge pull request #22 from FilipeMaia/array_fixes
Avoid promoting arrays to 64-bit
2 parents dd77152 + 66695c7 commit b0429c7

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

arrayfire/arith.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ def arith_binary_func(lhs, rhs, c_func):
2525

2626
elif (is_number(rhs)):
2727
ldims = dim4_tuple(lhs.dims())
28-
rty = number_dtype(rhs)
28+
rty = implicit_dtype(rhs, lhs.type())
2929
other = array()
3030
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
3131
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
3232

3333
else:
3434
rdims = dim4_tuple(rhs.dims())
35-
lty = number_dtype(lhs)
35+
lty = implicit_dtype(lhs, rhs.type())
3636
other = array()
3737
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
3838
safe_call(c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get()))

arrayfire/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def binary_func(lhs, rhs, c_func):
6060

6161
if (is_number(rhs)):
6262
ldims = dim4_tuple(lhs.dims())
63-
rty = number_dtype(rhs)
63+
rty = implicit_dtype(rhs, lhs.type())
6464
other = array()
6565
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
6666
elif not isinstance(rhs, array):
@@ -76,7 +76,7 @@ def binary_funcr(lhs, rhs, c_func):
7676

7777
if (is_number(lhs)):
7878
rdims = dim4_tuple(rhs.dims())
79-
lty = number_dtype(lhs)
79+
lty = implicit_dtype(lhs, rhs.type())
8080
other = array()
8181
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
8282
elif not isinstance(lhs, array):

arrayfire/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def number_dtype(a):
3434
else:
3535
return to_dtype[a.dtype.char]
3636

37+
def implicit_dtype(number, a_dtype):
38+
n_dtype = number_dtype(number)
39+
if n_dtype == f64 and (a_dtype == f32 or a_dtype == c32):
40+
return f32
41+
if n_dtype == c64 and (a_dtype == f32 or a_dtype == c32):
42+
return c32
43+
return n_dtype
44+
3745
def dim4_tuple(dims, default=1):
3846
assert(isinstance(dims, tuple))
3947

0 commit comments

Comments
 (0)