Skip to content

Commit 3075fe4

Browse files
committed
BUGFIX: Comparing ctype ints was causing incorrect implicit types
1 parent 6b543d2 commit 3075fe4

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

arrayfire/array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,13 @@ def elements(self):
196196
safe_call(clib.af_get_elements(ct.pointer(num), self.arr))
197197
return num.value
198198

199-
def type(self):
199+
def dtype(self):
200200
dty = ct.c_int(f32.value)
201201
safe_call(clib.af_get_type(ct.pointer(dty), self.arr))
202-
return dty.value
202+
return dty
203+
204+
def type(self):
205+
return self.dtype().value
203206

204207
def dims(self):
205208
d0 = ct.c_longlong(0)

arrayfire/util.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,19 @@ def number_dtype(a):
3636

3737
def implicit_dtype(number, a_dtype):
3838
n_dtype = number_dtype(number)
39-
if n_dtype == f64 and (a_dtype == f32 or a_dtype == c32):
39+
n_value = n_dtype.value
40+
41+
f64v = f64.value
42+
f32v = f32.value
43+
c32v = c32.value
44+
c64v = c64.value
45+
46+
if n_value == f64v and (a_dtype == f32v or a_dtype == c32v):
4047
return f32
41-
if n_dtype == c64 and (a_dtype == f32 or a_dtype == c32):
48+
49+
if n_value == c64v and (a_dtype == f32v or a_dtype == c32v):
4250
return c32
51+
4352
return n_dtype
4453

4554
def dim4_tuple(dims, default=1):

0 commit comments

Comments
 (0)