Skip to content

Commit fed169f

Browse files
committed
Implement a cleaner way for the Enum class
- Revert "Work around for missing Enum class" - This reverts commit ce1ae09.
1 parent 2a52463 commit fed169f

File tree

11 files changed

+162
-164
lines changed

11 files changed

+162
-164
lines changed

arrayfire/arith.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def cast(a, dtype):
7474
array containing the values from `a` after converting to `dtype`.
7575
"""
7676
out=Array()
77-
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, Enum_value(dtype)))
77+
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, dtype.value))
7878
return out
7979

8080
def minof(lhs, rhs):

arrayfire/array.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def _create_array(buf, numdims, idims, dtype):
2222
out_arr = ct.c_void_p(0)
2323
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
2424
safe_call(backend.get().af_create_array(ct.pointer(out_arr), ct.c_void_p(buf),
25-
numdims, ct.pointer(c_dims), Enum_value(dtype)))
25+
numdims, ct.pointer(c_dims), dtype.value))
2626
return out_arr
2727

2828
def _create_empty_array(numdims, idims, dtype):
2929
out_arr = ct.c_void_p(0)
3030
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
3131
safe_call(backend.get().af_create_handle(ct.pointer(out_arr),
32-
numdims, ct.pointer(c_dims), Enum_value(dtype)))
32+
numdims, ct.pointer(c_dims), dtype.value))
3333
return out_arr
3434

3535
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
@@ -41,7 +41,7 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
4141
if isinstance(dtype, int):
4242
dtype = ct.c_int(dtype)
4343
elif isinstance(dtype, Dtype):
44-
dtype = ct.c_int(Enum_value(dtype))
44+
dtype = ct.c_int(dtype.value)
4545
else:
4646
raise TypeError("Invalid dtype")
4747

@@ -52,16 +52,15 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
5252
c_real = ct.c_double(val.real)
5353
c_imag = ct.c_double(val.imag)
5454

55-
if (Enum_value(dtype) != Enum_value(Dtype) and Enum_value(dtype) != Enum_value(Dtype)):
56-
dtype = Enum_value(Dtype.c32)
55+
if (dtype.value != Dtype.c32.value and dtype.value != Dtype.c64.value):
56+
dtype = Dtype.c32.value
5757

5858
safe_call(backend.get().af_constant_complex(ct.pointer(out), c_real, c_imag,
5959
4, ct.pointer(dims), dtype))
60-
61-
elif Enum_value(dtype) == Enum_value(Dtype.s64):
60+
elif dtype.value == Dtype.s64.value:
6261
c_val = ct.c_longlong(val.real)
6362
safe_call(backend.get().af_constant_long(ct.pointer(out), c_val, 4, ct.pointer(dims)))
64-
elif Enum_value(dtype) == Enum_value(Dtype.u64):
63+
elif dtype.value == Dtype.u64.value:
6564
c_val = ct.c_ulonglong(val.real)
6665
safe_call(backend.get().af_constant_ulong(ct.pointer(out), c_val, 4, ct.pointer(dims)))
6766
else:
@@ -79,7 +78,7 @@ def _binary_func(lhs, rhs, c_func):
7978
ldims = dim4_to_tuple(lhs.dims())
8079
rty = implicit_dtype(rhs, lhs.type())
8180
other = Array()
82-
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], Enum_value(rty))
81+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty.value)
8382
elif not isinstance(rhs, Array):
8483
raise TypeError("Invalid parameter to binary function")
8584

@@ -95,7 +94,7 @@ def _binary_funcr(lhs, rhs, c_func):
9594
rdims = dim4_to_tuple(rhs.dims())
9695
lty = implicit_dtype(lhs, rhs.type())
9796
other = Array()
98-
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], Enum_value(lty))
97+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty.value)
9998
elif not isinstance(lhs, Array):
10099
raise TypeError("Invalid parameter to binary function")
101100

@@ -351,7 +350,7 @@ def __init__(self, src=None, dims=(0,), dtype=None):
351350
if isinstance(dtype, str):
352351
type_char = dtype
353352
else:
354-
type_char = to_typecode[Enum_value(dtype)]
353+
type_char = to_typecode[dtype.value]
355354
else:
356355
type_char = None
357356

@@ -463,15 +462,15 @@ def dtype(self):
463462
"""
464463
Return the data type as a arrayfire.Dtype enum value.
465464
"""
466-
dty = ct.c_int(Enum_value(Dtype.f32))
465+
dty = ct.c_int(Dtype.f32.value)
467466
safe_call(backend.get().af_get_type(ct.pointer(dty), self.arr))
468467
return to_dtype[typecodes[dty.value]]
469468

470469
def type(self):
471470
"""
472471
Return the data type as an int.
473472
"""
474-
return Enum_value(self.dtype())
473+
return self.dtype().value
475474

476475
def dims(self):
477476
"""

arrayfire/blas.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def matmul(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
5454
"""
5555
out = Array()
5656
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
57-
Enum_value(lhs_opts), Enum_value(rhs_opts)))
57+
lhs_opts.value, rhs_opts.value))
5858
return out
5959

6060
def matmulTN(lhs, rhs):
@@ -85,7 +85,7 @@ def matmulTN(lhs, rhs):
8585
"""
8686
out = Array()
8787
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
88-
Enum_value(MATPROP.TRANS), Enum_value(MATPROP.NONE)))
88+
MATPROP.TRANS.value, MATPROP.NONE.value))
8989
return out
9090

9191
def matmulNT(lhs, rhs):
@@ -116,7 +116,7 @@ def matmulNT(lhs, rhs):
116116
"""
117117
out = Array()
118118
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
119-
Enum_value(MATPROP.NONE), Enum_value(MATPROP.TRANS)))
119+
MATPROP.NONE.value, MATPROP.TRANS.value))
120120
return out
121121

122122
def matmulTT(lhs, rhs):
@@ -147,7 +147,7 @@ def matmulTT(lhs, rhs):
147147
"""
148148
out = Array()
149149
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
150-
Enum_value(MATPROP.TRANS), Enum_value(MATPROP.TRANS)))
150+
MATPROP.TRANS.value, MATPROP.TRANS.value))
151151
return out
152152

153153
def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
@@ -188,5 +188,5 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
188188
"""
189189
out = Array()
190190
safe_call(backend.get().af_dot(ct.pointer(out.arr), lhs.arr, rhs.arr,
191-
Enum_value(lhs_opts), Enum_value(rhs_opts)))
191+
lhs_opts.value, rhs_opts.value))
192192
return out

arrayfire/data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
5252
"""
5353

5454
out = Array()
55-
out.arr = constant_array(val, d0, d1, d2, d3, Enum_value(dtype))
55+
out.arr = constant_array(val, d0, d1, d2, d3, dtype.value)
5656
return out
5757

5858
# Store builtin range function to be used later
@@ -116,7 +116,7 @@ def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=Dtype.f32):
116116
out = Array()
117117
dims = dim4(d0, d1, d2, d3)
118118

119-
safe_call(backend.get().af_range(ct.pointer(out.arr), 4, ct.pointer(dims), dim, Enum_value(dtype)))
119+
safe_call(backend.get().af_range(ct.pointer(out.arr), 4, ct.pointer(dims), dim, dtype.value))
120120
return out
121121

122122

@@ -182,7 +182,7 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32)
182182
tdims = dim4(td[0], td[1], td[2], td[3])
183183

184184
safe_call(backend.get().af_iota(ct.pointer(out.arr), 4, ct.pointer(dims),
185-
4, ct.pointer(tdims), Enum_value(dtype)))
185+
4, ct.pointer(tdims), dtype.value))
186186
return out
187187

188188
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
@@ -219,7 +219,7 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
219219
out = Array()
220220
dims = dim4(d0, d1, d2, d3)
221221

222-
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), Enum_value(dtype)))
222+
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
223223
return out
224224

225225
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
@@ -257,7 +257,7 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
257257
out = Array()
258258
dims = dim4(d0, d1, d2, d3)
259259

260-
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), Enum_value(dtype)))
260+
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
261261
return out
262262

263263
def set_seed(seed=0):
@@ -318,7 +318,7 @@ def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32):
318318
out = Array()
319319
dims = dim4(d0, d1, d2, d3)
320320

321-
safe_call(backend.get().af_identity(ct.pointer(out.arr), 4, ct.pointer(dims), Enum_value(dtype)))
321+
safe_call(backend.get().af_identity(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
322322
return out
323323

324324
def diag(a, num=0, extract=True):

arrayfire/graphics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, r, c, title, cmap):
2424
self.row = r
2525
self.col = c
2626
self.title = title if title is not None else ct.c_char_p()
27-
self.cmap = Enum_value(cmap)
27+
self.cmap = cmap.value
2828

2929
class Window(object):
3030
"""

arrayfire/image.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def resize(image, scale=None, odim0=None, odim1=None, method=INTERP.NEAREST):
123123
output = Array()
124124
safe_call(backend.get().af_resize(ct.pointer(output.arr),
125125
image.arr, ct.c_longlong(odim0),
126-
ct.c_longlong(odim1), Enum_value(method)))
126+
ct.c_longlong(odim1), method.value))
127127

128128
return output
129129

@@ -167,7 +167,7 @@ def transform(image, trans_mat, odim0 = 0, odim1 = 0, method=INTERP.NEAREST, is_
167167
safe_call(backend.get().af_transform(ct.pointer(output.arr),
168168
image.arr, trans_mat.arr,
169169
ct.c_longlong(odim0), ct.c_longlong(odim1),
170-
Enum_value(method), is_inverse))
170+
method.value, is_inverse))
171171
return output
172172

173173
def rotate(image, theta, is_crop = True, method = INTERP.NEAREST):
@@ -196,7 +196,7 @@ def rotate(image, theta, is_crop = True, method = INTERP.NEAREST):
196196
"""
197197
output = Array()
198198
safe_call(backend.get().af_rotate(ct.pointer(output.arr), image.arr,
199-
ct.c_double(theta), is_crop, Enum_value(method)))
199+
ct.c_double(theta), is_crop, method.value))
200200
return output
201201

202202
def translate(image, trans0, trans1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST):
@@ -238,7 +238,7 @@ def translate(image, trans0, trans1, odim0 = 0, odim1 = 0, method = INTERP.NEARE
238238
output = Array()
239239
safe_call(backend.get().af_translate(ct.pointer(output.arr),
240240
image.arr, trans0, trans1,
241-
ct.c_longlong(odim0), ct.c_longlong(odim1), Enum_value(method)))
241+
ct.c_longlong(odim0), ct.c_longlong(odim1), method.value))
242242
return output
243243

244244
def scale(image, scale0, scale1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST):
@@ -280,7 +280,7 @@ def scale(image, scale0, scale1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST):
280280
output = Array()
281281
safe_call(backend.get().af_scale(ct.pointer(output.arr),
282282
image.arr, ct.c_double(scale0), ct.c_double(scale1),
283-
ct.c_longlong(odim0), ct.c_longlong(odim1), Enum_value(method)))
283+
ct.c_longlong(odim0), ct.c_longlong(odim1), method.value))
284284
return output
285285

286286
def skew(image, skew0, skew1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST, is_inverse=True):
@@ -326,7 +326,7 @@ def skew(image, skew0, skew1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST, is_
326326
safe_call(backend.get().af_skew(ct.pointer(output.arr),
327327
image.arr, ct.c_double(skew0), ct.c_double(skew1),
328328
ct.c_longlong(odim0), ct.c_longlong(odim1),
329-
Enum_value(method), is_inverse))
329+
method.value, is_inverse))
330330

331331
return output
332332

@@ -609,7 +609,7 @@ def medfilt(image, w0 = 3, w1 = 3, edge_pad = PAD.ZERO):
609609
output = Array()
610610
safe_call(backend.get().af_medfilt(ct.pointer(output.arr),
611611
image.arr, ct.c_longlong(w0),
612-
ct.c_longlong(w1), Enum_value(edge_pad)))
612+
ct.c_longlong(w1), edge_pad.value))
613613
return output
614614

615615
def minfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
@@ -641,7 +641,7 @@ def minfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
641641
output = Array()
642642
safe_call(backend.get().af_minfilt(ct.pointer(output.arr),
643643
image.arr, ct.c_longlong(w_len),
644-
ct.c_longlong(w_wid), Enum_value(edge_pad)))
644+
ct.c_longlong(w_wid), edge_pad.value))
645645
return output
646646

647647
def maxfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
@@ -673,7 +673,7 @@ def maxfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
673673
output = Array()
674674
safe_call(backend.get().af_maxfilt(ct.pointer(output.arr),
675675
image.arr, ct.c_longlong(w_len),
676-
ct.c_longlong(w_wid), Enum_value(edge_pad)))
676+
ct.c_longlong(w_wid), edge_pad.value))
677677
return output
678678

679679
def regions(image, conn = CONNECTIVITY.FOUR, out_type = Dtype.f32):
@@ -700,7 +700,7 @@ def regions(image, conn = CONNECTIVITY.FOUR, out_type = Dtype.f32):
700700
"""
701701
output = Array()
702702
safe_call(backend.get().af_regions(ct.pointer(output.arr), image.arr,
703-
Enum_value(conn), Enum_value(out_type)))
703+
conn.value, out_type.value))
704704
return output
705705

706706
def sobel_derivatives(image, w_len=3):
@@ -891,5 +891,5 @@ def color_space(image, to_type, from_type):
891891
"""
892892
output = Array()
893893
safe_call(backend.get().af_color_space(ct.pointer(output.arr), image.arr,
894-
Enum_value(to_type), Enum_value(from_type)))
894+
to_type.value, from_type.value))
895895
return output

arrayfire/index.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,7 @@ def __init__ (self, idx):
195195

196196
arr = ct.c_void_p(0)
197197

198-
if (Enum_value(idx.dtype()) ==
199-
Enum_value(Dtype.b8)):
198+
if (idx.type() == Dtype.b8.value):
200199
safe_call(backend.get().af_where(ct.pointer(arr), idx.arr))
201200
else:
202201
safe_call(backend.get().af_retain_array(ct.pointer(arr), idx.arr))

arrayfire/lapack.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def solve(A, B, options=MATPROP.NONE):
202202
203203
"""
204204
X = Array()
205-
safe_call(backend.get().af_solve(ct.pointer(X.arr), A.arr, B.arr, Enum_value(options)))
205+
safe_call(backend.get().af_solve(ct.pointer(X.arr), A.arr, B.arr, options.value))
206206
return X
207207

208208
def solve_lu(A, P, B, options=MATPROP.NONE):
@@ -230,7 +230,7 @@ def solve_lu(A, P, B, options=MATPROP.NONE):
230230
231231
"""
232232
X = Array()
233-
safe_call(backend.get().af_solve_lu(ct.pointer(X.arr), A.arr, P.arr, B.arr, Enum_value(options)))
233+
safe_call(backend.get().af_solve_lu(ct.pointer(X.arr), A.arr, P.arr, B.arr, options.value))
234234
return X
235235

236236
def inverse(A, options=MATPROP.NONE):
@@ -260,7 +260,7 @@ def inverse(A, options=MATPROP.NONE):
260260
261261
"""
262262
AI = Array()
263-
safe_call(backend.get().af_inverse(ct.pointer(AI.arr), A.arr, Enum_value(options)))
263+
safe_call(backend.get().af_inverse(ct.pointer(AI.arr), A.arr, options.value))
264264
return AI
265265

266266
def rank(A, tol=1E-5):
@@ -336,6 +336,6 @@ def norm(A, norm_type=NORM.EUCLID, p=1.0, q=1.0):
336336
337337
"""
338338
res = ct.c_double(0)
339-
safe_call(backend.get().af_norm(ct.pointer(res), A.arr, Enum_value(norm_type),
339+
safe_call(backend.get().af_norm(ct.pointer(res), A.arr, norm_type.value,
340340
ct.c_double(p), ct.c_double(q)))
341341
return res.value

0 commit comments

Comments
 (0)