Skip to content

Commit 9706463

Browse files
committed
Derive the dtype directly from the number
1 parent 7d6eacc commit 9706463

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

arrayfire/arith.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ def arith_binary_func(lhs, rhs, c_func):
2525

2626
elif (is_number(rhs)):
2727
ldims = dim4_tuple(lhs.dims())
28-
lty = lhs.type()
28+
rty = number_dtype(rhs)
2929
other = array()
30-
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
30+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
3131
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
3232

3333
else:
3434
rdims = dim4_tuple(rhs.dims())
35-
rty = rhs.type()
35+
lty = number_dtype(lhs)
3636
other = array()
37-
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
37+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
3838
safe_call(c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get()))
3939

4040
return out

arrayfire/array.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def binary_func(lhs, rhs, c_func):
6060

6161
if (is_number(rhs)):
6262
ldims = dim4_tuple(lhs.dims())
63-
lty = lhs.type()
63+
rty = number_dtype(rhs)
6464
other = array()
65-
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
65+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
6666
elif not isinstance(rhs, array):
6767
raise TypeError("Invalid parameter to binary function")
6868

@@ -76,9 +76,9 @@ def binary_funcr(lhs, rhs, c_func):
7676

7777
if (is_number(lhs)):
7878
rdims = dim4_tuple(rhs.dims())
79-
rty = rhs.type()
79+
lty = number_dtype(lhs)
8080
other = array()
81-
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
81+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
8282
elif not isinstance(lhs, array):
8383
raise TypeError("Invalid parameter to binary function")
8484

0 commit comments

Comments
 (0)