@@ -22,24 +22,26 @@ def _create_array(buf, numdims, idims, dtype):
22
22
out_arr = ct .c_void_p (0 )
23
23
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
24
24
safe_call (backend .get ().af_create_array (ct .pointer (out_arr ), ct .c_void_p (buf ),
25
- numdims , ct .pointer (c_dims ), dtype ))
25
+ numdims , ct .pointer (c_dims ), dtype . value ))
26
26
return out_arr
27
27
28
28
def _create_empty_array (numdims , idims , dtype ):
29
29
out_arr = ct .c_void_p (0 )
30
30
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
31
31
safe_call (backend .get ().af_create_handle (ct .pointer (out_arr ),
32
- numdims , ct .pointer (c_dims ), dtype ))
32
+ numdims , ct .pointer (c_dims ), dtype . value ))
33
33
return out_arr
34
34
35
- def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
35
+ def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = Dtype . f32 ):
36
36
"""
37
37
Internal function to create a C array. Should not be used externall.
38
38
"""
39
39
40
40
if not isinstance (dtype , ct .c_int ):
41
41
if isinstance (dtype , int ):
42
42
dtype = ct .c_int (dtype )
43
+ elif isinstance (dtype , Dtype ):
44
+ dtype = ct .c_int (dtype .value )
43
45
else :
44
46
raise TypeError ("Invalid dtype" )
45
47
@@ -50,15 +52,15 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
50
52
c_real = ct .c_double (val .real )
51
53
c_imag = ct .c_double (val .imag )
52
54
53
- if (dtype != c32 and dtype != c64 ):
54
- dtype = c32
55
+ if (dtype . value != Dtype . c32 . value and dtype . value != Dtype . c64 . value ):
56
+ dtype = Dtype . c32 . value
55
57
56
58
safe_call (backend .get ().af_constant_complex (ct .pointer (out ), c_real , c_imag ,
57
- 4 , ct .pointer (dims ), dtype ))
58
- elif dtype == s64 :
59
+ 4 , ct .pointer (dims ), dtype ))
60
+ elif dtype . value == Dtype . s64 . value :
59
61
c_val = ct .c_longlong (val .real )
60
62
safe_call (backend .get ().af_constant_long (ct .pointer (out ), c_val , 4 , ct .pointer (dims )))
61
- elif dtype == u64 :
63
+ elif dtype . value == Dtype . u64 . value :
62
64
c_val = ct .c_ulonglong (val .real )
63
65
safe_call (backend .get ().af_constant_ulong (ct .pointer (out ), c_val , 4 , ct .pointer (dims )))
64
66
else :
@@ -76,7 +78,7 @@ def _binary_func(lhs, rhs, c_func):
76
78
ldims = dim4_to_tuple (lhs .dims ())
77
79
rty = implicit_dtype (rhs , lhs .type ())
78
80
other = Array ()
79
- other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], rty )
81
+ other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], rty . value )
80
82
elif not isinstance (rhs , Array ):
81
83
raise TypeError ("Invalid parameter to binary function" )
82
84
@@ -92,7 +94,7 @@ def _binary_funcr(lhs, rhs, c_func):
92
94
rdims = dim4_to_tuple (rhs .dims ())
93
95
lty = implicit_dtype (lhs , rhs .type ())
94
96
other = Array ()
95
- other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], lty )
97
+ other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], lty . value )
96
98
elif not isinstance (lhs , Array ):
97
99
raise TypeError ("Invalid parameter to binary function" )
98
100
@@ -186,7 +188,7 @@ class Array(BaseArray):
186
188
dims : optional: tuple of ints. default: (0,)
187
189
- When using the default values of `dims`, the dims are caclulated as `len(src)`
188
190
189
- dtype: optional: str or ctypes.c_int . default: None.
191
+ dtype: optional: str or arrayfire.Dtype . default: None.
190
192
- if str, must be one of the following:
191
193
- 'f' for float
192
194
- 'd' for double
@@ -198,18 +200,18 @@ class Array(BaseArray):
198
200
- 'L' for unsigned 64 bit integer
199
201
- 'F' for 32 bit complex number
200
202
- 'D' for 64 bit complex number
201
- - if ctypes.c_int , must be one of the following:
202
- - f32 for float
203
- - f64 for double
204
- - b8 for bool
205
- - u8 for unsigned char
206
- - s32 for signed 32 bit integer
207
- - u32 for unsigned 32 bit integer
208
- - s64 for signed 64 bit integer
209
- - u64 for unsigned 64 bit integer
210
- - c32 for 32 bit complex number
211
- - c64 for 64 bit complex number
212
- - if None, f32 is assumed
203
+ - if arrayfire.Dtype , must be one of the following:
204
+ - Dtype. f32 for float
205
+ - Dtype. f64 for double
206
+ - Dtype. b8 for bool
207
+ - Dtype. u8 for unsigned char
208
+ - Dtype. s32 for signed 32 bit integer
209
+ - Dtype. u32 for unsigned 32 bit integer
210
+ - Dtype. s64 for signed 64 bit integer
211
+ - Dtype. u64 for unsigned 64 bit integer
212
+ - Dtype. c32 for 32 bit complex number
213
+ - Dtype. c64 for 64 bit complex number
214
+ - if None, Dtype. f32 is assumed
213
215
214
216
Attributes
215
217
-----------
@@ -281,7 +283,6 @@ def __init__(self, src=None, dims=(0,), dtype=None):
281
283
type_char = None
282
284
283
285
_type_char = 'f'
284
- dtype = f32
285
286
286
287
backend .lock ()
287
288
@@ -318,8 +319,6 @@ def __init__(self, src=None, dims=(0,), dtype=None):
318
319
319
320
_type_char = type_char
320
321
321
- print (_type_char )
322
-
323
322
else :
324
323
raise TypeError ("src is an object of unsupported class" )
325
324
@@ -389,11 +388,11 @@ def elements(self):
389
388
390
389
def dtype (self ):
391
390
"""
392
- Return the data type as a ctypes.c_int value.
391
+ Return the data type as a arrayfire.Dtype enum value.
393
392
"""
394
- dty = ct .c_int (f32 .value )
393
+ dty = ct .c_int (Dtype . f32 .value )
395
394
safe_call (backend .get ().af_get_type (ct .pointer (dty ), self .arr ))
396
- return dty
395
+ return Dtype ( dty . value )
397
396
398
397
def type (self ):
399
398
"""
0 commit comments