Skip to content

Commit c8ec234

Browse files
committed
STYLE: Use Python Enum instead of ctypes.c_int for enums
1 parent 2005662 commit c8ec234

17 files changed

+318
-304
lines changed

arrayfire/arith.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,25 @@ def cast(a, dtype):
5656
----------
5757
a : af.Array
5858
Multi dimensional arrayfire array.
59-
dtype: ctypes.c_int.
59+
dtype: af.Dtype
6060
Must be one of the following:
61-
af.f32, af.f64, af.c32, af.c64
62-
af.s32, af.s64, af.u32, af.u64,
63-
af.b8, af.u8
64-
61+
- Dtype.f32 for float
62+
- Dtype.f64 for double
63+
- Dtype.b8 for bool
64+
- Dtype.u8 for unsigned char
65+
- Dtype.s32 for signed 32 bit integer
66+
- Dtype.u32 for unsigned 32 bit integer
67+
- Dtype.s64 for signed 64 bit integer
68+
- Dtype.u64 for unsigned 64 bit integer
69+
- Dtype.c32 for 32 bit complex number
70+
- Dtype.c64 for 64 bit complex number
6571
Returns
6672
--------
6773
out : af.Array
6874
array containing the values from `a` after converting to `dtype`.
6975
"""
7076
out=Array()
71-
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, dtype))
77+
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, dtype.value))
7278
return out
7379

7480
def minof(lhs, rhs):

arrayfire/array.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,26 @@ def _create_array(buf, numdims, idims, dtype):
2222
out_arr = ct.c_void_p(0)
2323
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
2424
safe_call(backend.get().af_create_array(ct.pointer(out_arr), ct.c_void_p(buf),
25-
numdims, ct.pointer(c_dims), dtype))
25+
numdims, ct.pointer(c_dims), dtype.value))
2626
return out_arr
2727

2828
def _create_empty_array(numdims, idims, dtype):
2929
out_arr = ct.c_void_p(0)
3030
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
3131
safe_call(backend.get().af_create_handle(ct.pointer(out_arr),
32-
numdims, ct.pointer(c_dims), dtype))
32+
numdims, ct.pointer(c_dims), dtype.value))
3333
return out_arr
3434

35-
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
35+
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
3636
"""
3737
Internal function to create a C array. Should not be used externall.
3838
"""
3939

4040
if not isinstance(dtype, ct.c_int):
4141
if isinstance(dtype, int):
4242
dtype = ct.c_int(dtype)
43+
elif isinstance(dtype, Dtype):
44+
dtype = ct.c_int(dtype.value)
4345
else:
4446
raise TypeError("Invalid dtype")
4547

@@ -50,15 +52,15 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
5052
c_real = ct.c_double(val.real)
5153
c_imag = ct.c_double(val.imag)
5254

53-
if (dtype != c32 and dtype != c64):
54-
dtype = c32
55+
if (dtype.value != Dtype.c32.value and dtype.value != Dtype.c64.value):
56+
dtype = Dtype.c32.value
5557

5658
safe_call(backend.get().af_constant_complex(ct.pointer(out), c_real, c_imag,
57-
4, ct.pointer(dims), dtype))
58-
elif dtype == s64:
59+
4, ct.pointer(dims), dtype))
60+
elif dtype.value == Dtype.s64.value:
5961
c_val = ct.c_longlong(val.real)
6062
safe_call(backend.get().af_constant_long(ct.pointer(out), c_val, 4, ct.pointer(dims)))
61-
elif dtype == u64:
63+
elif dtype.value == Dtype.u64.value:
6264
c_val = ct.c_ulonglong(val.real)
6365
safe_call(backend.get().af_constant_ulong(ct.pointer(out), c_val, 4, ct.pointer(dims)))
6466
else:
@@ -76,7 +78,7 @@ def _binary_func(lhs, rhs, c_func):
7678
ldims = dim4_to_tuple(lhs.dims())
7779
rty = implicit_dtype(rhs, lhs.type())
7880
other = Array()
79-
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
81+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty.value)
8082
elif not isinstance(rhs, Array):
8183
raise TypeError("Invalid parameter to binary function")
8284

@@ -92,7 +94,7 @@ def _binary_funcr(lhs, rhs, c_func):
9294
rdims = dim4_to_tuple(rhs.dims())
9395
lty = implicit_dtype(lhs, rhs.type())
9496
other = Array()
95-
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
97+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty.value)
9698
elif not isinstance(lhs, Array):
9799
raise TypeError("Invalid parameter to binary function")
98100

@@ -186,7 +188,7 @@ class Array(BaseArray):
186188
dims : optional: tuple of ints. default: (0,)
187189
- When using the default values of `dims`, the dims are caclulated as `len(src)`
188190
189-
dtype: optional: str or ctypes.c_int. default: None.
191+
dtype: optional: str or arrayfire.Dtype. default: None.
190192
- if str, must be one of the following:
191193
- 'f' for float
192194
- 'd' for double
@@ -198,18 +200,18 @@ class Array(BaseArray):
198200
- 'L' for unsigned 64 bit integer
199201
- 'F' for 32 bit complex number
200202
- 'D' for 64 bit complex number
201-
- if ctypes.c_int, must be one of the following:
202-
- f32 for float
203-
- f64 for double
204-
- b8 for bool
205-
- u8 for unsigned char
206-
- s32 for signed 32 bit integer
207-
- u32 for unsigned 32 bit integer
208-
- s64 for signed 64 bit integer
209-
- u64 for unsigned 64 bit integer
210-
- c32 for 32 bit complex number
211-
- c64 for 64 bit complex number
212-
- if None, f32 is assumed
203+
- if arrayfire.Dtype, must be one of the following:
204+
- Dtype.f32 for float
205+
- Dtype.f64 for double
206+
- Dtype.b8 for bool
207+
- Dtype.u8 for unsigned char
208+
- Dtype.s32 for signed 32 bit integer
209+
- Dtype.u32 for unsigned 32 bit integer
210+
- Dtype.s64 for signed 64 bit integer
211+
- Dtype.u64 for unsigned 64 bit integer
212+
- Dtype.c32 for 32 bit complex number
213+
- Dtype.c64 for 64 bit complex number
214+
- if None, Dtype.f32 is assumed
213215
214216
Attributes
215217
-----------
@@ -281,7 +283,6 @@ def __init__(self, src=None, dims=(0,), dtype=None):
281283
type_char = None
282284

283285
_type_char='f'
284-
dtype = f32
285286

286287
backend.lock()
287288

@@ -318,8 +319,6 @@ def __init__(self, src=None, dims=(0,), dtype=None):
318319

319320
_type_char = type_char
320321

321-
print(_type_char)
322-
323322
else:
324323
raise TypeError("src is an object of unsupported class")
325324

@@ -389,11 +388,11 @@ def elements(self):
389388

390389
def dtype(self):
391390
"""
392-
Return the data type as a ctypes.c_int value.
391+
Return the data type as a arrayfire.Dtype enum value.
393392
"""
394-
dty = ct.c_int(f32.value)
393+
dty = ct.c_int(Dtype.f32.value)
395394
safe_call(backend.get().af_get_type(ct.pointer(dty), self.arr))
396-
return dty
395+
return Dtype(dty.value)
397396

398397
def type(self):
399398
"""

arrayfire/blas.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,32 @@
1010
from .library import *
1111
from .array import *
1212

13-
def matmul(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
13+
def matmul(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
1414
out = Array()
1515
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
16-
lhs_opts, rhs_opts))
16+
lhs_opts.value, rhs_opts.value))
1717
return out
1818

1919
def matmulTN(lhs, rhs):
2020
out = Array()
2121
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
22-
AF_MAT_TRANS, AF_MAT_NONE))
22+
MATPROP.TRANS.value, MATPROP.NONE.value))
2323
return out
2424

2525
def matmulNT(lhs, rhs):
2626
out = Array()
2727
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
28-
AF_MAT_NONE, AF_MAT_TRANS))
28+
MATPROP.NONE.value, MATPROP.TRANS.value))
2929
return out
3030

3131
def matmulTT(lhs, rhs):
3232
out = Array()
3333
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
34-
AF_MAT_TRANS, AF_MAT_TRANS))
34+
MATPROP.TRANS.value, MATPROP.TRANS.value))
3535
return out
3636

37-
def dot(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
37+
def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
3838
out = Array()
3939
safe_call(backend.get().af_dot(ct.pointer(out.arr), lhs.arr, rhs.arr,
40-
lhs_opts, rhs_opts))
40+
lhs_opts.value, rhs_opts.value))
4141
return out

arrayfire/data.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,24 @@
1212
from .array import *
1313
from .util import *
1414

15-
def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
15+
def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
1616
out = Array()
17-
out.arr = constant_array(val, d0, d1, d2, d3, dtype)
17+
out.arr = constant_array(val, d0, d1, d2, d3, dtype.value)
1818
return out
1919

2020
# Store builtin range function to be used later
2121
_brange = range
2222

23-
def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
24-
25-
if not isinstance(dtype, ct.c_int):
26-
if isinstance(dtype, int):
27-
dtype = ct.c_int(dtype)
28-
else:
29-
raise TypeError("Invalid dtype")
23+
def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=Dtype.f32):
3024

3125
out = Array()
3226
dims = dim4(d0, d1, d2, d3)
3327

34-
safe_call(backend.get().af_range(ct.pointer(out.arr), 4, ct.pointer(dims), dim, dtype))
28+
safe_call(backend.get().af_range(ct.pointer(out.arr), 4, ct.pointer(dims), dim, dtype.value))
3529
return out
3630

3731

38-
def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
39-
if not isinstance(dtype, ct.c_int):
40-
if isinstance(dtype, int):
41-
dtype = ct.c_int(dtype)
42-
else:
43-
raise TypeError("Invalid dtype")
32+
def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32):
4433

4534
out = Array()
4635
dims = dim4(d0, d1, d2, d3)
@@ -52,35 +41,24 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
5241

5342
tdims = dim4(td[0], td[1], td[2], td[3])
5443

55-
safe_call(backend.get().af_iota(ct.pointer(out.arr), 4, ct.pointer(dims), 4, ct.pointer(tdims), dtype))
44+
safe_call(backend.get().af_iota(ct.pointer(out.arr), 4, ct.pointer(dims),
45+
4, ct.pointer(tdims), dtype.value))
5646
return out
5747

58-
def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
59-
60-
if not isinstance(dtype, ct.c_int):
61-
if isinstance(dtype, int):
62-
dtype = ct.c_int(dtype)
63-
else:
64-
raise TypeError("Invalid dtype")
48+
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
6549

6650
out = Array()
6751
dims = dim4(d0, d1, d2, d3)
6852

69-
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype))
53+
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
7054
return out
7155

72-
def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
73-
74-
if not isinstance(dtype, ct.c_int):
75-
if isinstance(dtype, int):
76-
dtype = ct.c_int(dtype)
77-
else:
78-
raise TypeError("Invalid dtype")
56+
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
7957

8058
out = Array()
8159
dims = dim4(d0, d1, d2, d3)
8260

83-
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype))
61+
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
8462
return out
8563

8664
def set_seed(seed=0):
@@ -91,18 +69,12 @@ def get_seed():
9169
safe_call(backend.get().af_get_seed(ct.pointer(seed)))
9270
return seed.value
9371

94-
def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
95-
96-
if not isinstance(dtype, ct.c_int):
97-
if isinstance(dtype, int):
98-
dtype = ct.c_int(dtype)
99-
else:
100-
raise TypeError("Invalid dtype")
72+
def identity(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
10173

10274
out = Array()
10375
dims = dim4(d0, d1, d2, d3)
10476

105-
safe_call(backend.get().af_identity(ct.pointer(out.arr), 4, ct.pointer(dims), dtype))
77+
safe_call(backend.get().af_identity(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
10678
return out
10779

10880
def diag(a, num=0, extract=True):

arrayfire/graphics.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, width=None, height=None, title=None):
2828
self._r = -1
2929
self._c = -1
3030
self._wnd = ct.c_longlong(0)
31-
self._cmap = AF_COLORMAP_DEFAULT
31+
self._cmap = COLORMAP.DEFAULT
3232

3333
_width = 1280 if width is None else width
3434
_height = 720 if height is None else height
@@ -37,7 +37,8 @@ def __init__(self, width=None, height=None, title=None):
3737
_title = _title.encode("ascii")
3838

3939
safe_call(backend.get().af_create_window(ct.pointer(self._wnd),
40-
ct.c_int(_width), ct.c_int(_height), ct.c_char_p(_title)))
40+
ct.c_int(_width), ct.c_int(_height),
41+
ct.c_char_p(_title)))
4142

4243
def __del__(self):
4344
safe_call(backend.get().af_destroy_window(self._wnd))

0 commit comments

Comments
 (0)