Skip to content

Commit ce1ae09

Browse files
committed
Work around for missing Enum class
1 parent f226b74 commit ce1ae09

File tree

10 files changed

+91
-77
lines changed

10 files changed

+91
-77
lines changed

arrayfire/arith.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def cast(a, dtype):
7474
array containing the values from `a` after converting to `dtype`.
7575
"""
7676
out=Array()
77-
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, dtype.value))
77+
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, Enum_value(dtype)))
7878
return out
7979

8080
def minof(lhs, rhs):

arrayfire/array.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def _create_array(buf, numdims, idims, dtype):
2222
out_arr = ct.c_void_p(0)
2323
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
2424
safe_call(backend.get().af_create_array(ct.pointer(out_arr), ct.c_void_p(buf),
25-
numdims, ct.pointer(c_dims), dtype.value))
25+
numdims, ct.pointer(c_dims), Enum_value(dtype)))
2626
return out_arr
2727

2828
def _create_empty_array(numdims, idims, dtype):
2929
out_arr = ct.c_void_p(0)
3030
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
3131
safe_call(backend.get().af_create_handle(ct.pointer(out_arr),
32-
numdims, ct.pointer(c_dims), dtype.value))
32+
numdims, ct.pointer(c_dims), Enum_value(dtype)))
3333
return out_arr
3434

3535
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
@@ -41,7 +41,7 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
4141
if isinstance(dtype, int):
4242
dtype = ct.c_int(dtype)
4343
elif isinstance(dtype, Dtype):
44-
dtype = ct.c_int(dtype.value)
44+
dtype = ct.c_int(Enum_value(dtype))
4545
else:
4646
raise TypeError("Invalid dtype")
4747

@@ -52,15 +52,16 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
5252
c_real = ct.c_double(val.real)
5353
c_imag = ct.c_double(val.imag)
5454

55-
if (dtype.value != Dtype.c32.value and dtype.value != Dtype.c64.value):
56-
dtype = Dtype.c32.value
55+
if (Enum_value(dtype) != Enum_value(Dtype) and Enum_value(dtype) != Enum_value(Dtype)):
56+
dtype = Enum_value(Dtype.c32)
5757

5858
safe_call(backend.get().af_constant_complex(ct.pointer(out), c_real, c_imag,
5959
4, ct.pointer(dims), dtype))
60-
elif dtype.value == Dtype.s64.value:
60+
61+
elif Enum_value(dtype) == Enum_value(Dtype.s64):
6162
c_val = ct.c_longlong(val.real)
6263
safe_call(backend.get().af_constant_long(ct.pointer(out), c_val, 4, ct.pointer(dims)))
63-
elif dtype.value == Dtype.u64.value:
64+
elif Enum_value(dtype) == Enum_value(Dtype.u64):
6465
c_val = ct.c_ulonglong(val.real)
6566
safe_call(backend.get().af_constant_ulong(ct.pointer(out), c_val, 4, ct.pointer(dims)))
6667
else:
@@ -78,7 +79,7 @@ def _binary_func(lhs, rhs, c_func):
7879
ldims = dim4_to_tuple(lhs.dims())
7980
rty = implicit_dtype(rhs, lhs.type())
8081
other = Array()
81-
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty.value)
82+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], Enum_value(rty))
8283
elif not isinstance(rhs, Array):
8384
raise TypeError("Invalid parameter to binary function")
8485

@@ -94,7 +95,7 @@ def _binary_funcr(lhs, rhs, c_func):
9495
rdims = dim4_to_tuple(rhs.dims())
9596
lty = implicit_dtype(lhs, rhs.type())
9697
other = Array()
97-
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty.value)
98+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], Enum_value(lty))
9899
elif not isinstance(lhs, Array):
99100
raise TypeError("Invalid parameter to binary function")
100101

@@ -350,7 +351,7 @@ def __init__(self, src=None, dims=(0,), dtype=None):
350351
if isinstance(dtype, str):
351352
type_char = dtype
352353
else:
353-
type_char = to_typecode[dtype.value]
354+
type_char = to_typecode[Enum_value(dtype)]
354355
else:
355356
type_char = None
356357

@@ -462,15 +463,15 @@ def dtype(self):
462463
"""
463464
Return the data type as a arrayfire.Dtype enum value.
464465
"""
465-
dty = ct.c_int(Dtype.f32.value)
466+
dty = ct.c_int(Enum_value(Dtype.f32))
466467
safe_call(backend.get().af_get_type(ct.pointer(dty), self.arr))
467-
return Dtype(dty.value)
468+
return to_dtype[typecodes[dty.value]]
468469

469470
def type(self):
470471
"""
471472
Return the data type as an int.
472473
"""
473-
return self.dtype().value
474+
return Enum_value(self.dtype())
474475

475476
def dims(self):
476477
"""

arrayfire/blas.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def matmul(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
5454
"""
5555
out = Array()
5656
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
57-
lhs_opts.value, rhs_opts.value))
57+
Enum_value(lhs_opts), Enum_value(rhs_opts)))
5858
return out
5959

6060
def matmulTN(lhs, rhs):
@@ -85,7 +85,7 @@ def matmulTN(lhs, rhs):
8585
"""
8686
out = Array()
8787
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
88-
MATPROP.TRANS.value, MATPROP.NONE.value))
88+
Enum_value(MATPROP.TRANS), Enum_value(MATPROP.NONE)))
8989
return out
9090

9191
def matmulNT(lhs, rhs):
@@ -116,7 +116,7 @@ def matmulNT(lhs, rhs):
116116
"""
117117
out = Array()
118118
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
119-
MATPROP.NONE.value, MATPROP.TRANS.value))
119+
Enum_value(MATPROP.NONE), Enum_value(MATPROP.TRANS)))
120120
return out
121121

122122
def matmulTT(lhs, rhs):
@@ -147,7 +147,7 @@ def matmulTT(lhs, rhs):
147147
"""
148148
out = Array()
149149
safe_call(backend.get().af_matmul(ct.pointer(out.arr), lhs.arr, rhs.arr,
150-
MATPROP.TRANS.value, MATPROP.TRANS.value))
150+
Enum_value(MATPROP.TRANS), Enum_value(MATPROP.TRANS)))
151151
return out
152152

153153
def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
@@ -188,5 +188,5 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
188188
"""
189189
out = Array()
190190
safe_call(backend.get().af_dot(ct.pointer(out.arr), lhs.arr, rhs.arr,
191-
lhs_opts.value, rhs_opts.value))
191+
Enum_value(lhs_opts), Enum_value(rhs_opts)))
192192
return out

arrayfire/data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
5252
"""
5353

5454
out = Array()
55-
out.arr = constant_array(val, d0, d1, d2, d3, dtype.value)
55+
out.arr = constant_array(val, d0, d1, d2, d3, Enum_value(dtype))
5656
return out
5757

5858
# Store builtin range function to be used later
@@ -116,7 +116,7 @@ def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=Dtype.f32):
116116
out = Array()
117117
dims = dim4(d0, d1, d2, d3)
118118

119-
safe_call(backend.get().af_range(ct.pointer(out.arr), 4, ct.pointer(dims), dim, dtype.value))
119+
safe_call(backend.get().af_range(ct.pointer(out.arr), 4, ct.pointer(dims), dim, Enum_value(dtype)))
120120
return out
121121

122122

@@ -182,7 +182,7 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32)
182182
tdims = dim4(td[0], td[1], td[2], td[3])
183183

184184
safe_call(backend.get().af_iota(ct.pointer(out.arr), 4, ct.pointer(dims),
185-
4, ct.pointer(tdims), dtype.value))
185+
4, ct.pointer(tdims), Enum_value(dtype)))
186186
return out
187187

188188
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
@@ -219,7 +219,7 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
219219
out = Array()
220220
dims = dim4(d0, d1, d2, d3)
221221

222-
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
222+
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), Enum_value(dtype)))
223223
return out
224224

225225
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
@@ -257,7 +257,7 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
257257
out = Array()
258258
dims = dim4(d0, d1, d2, d3)
259259

260-
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
260+
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), Enum_value(dtype)))
261261
return out
262262

263263
def set_seed(seed=0):
@@ -318,7 +318,7 @@ def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32):
318318
out = Array()
319319
dims = dim4(d0, d1, d2, d3)
320320

321-
safe_call(backend.get().af_identity(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
321+
safe_call(backend.get().af_identity(ct.pointer(out.arr), 4, ct.pointer(dims), Enum_value(dtype)))
322322
return out
323323

324324
def diag(a, num=0, extract=True):

arrayfire/graphics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, r, c, title, cmap):
2424
self.row = r
2525
self.col = c
2626
self.title = title if title is not None else ct.c_char_p()
27-
self.cmap = cmap.value
27+
self.cmap = Enum_value(cmap)
2828

2929
class Window(object):
3030
"""

arrayfire/image.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def resize(image, scale=None, odim0=None, odim1=None, method=INTERP.NEAREST):
123123
output = Array()
124124
safe_call(backend.get().af_resize(ct.pointer(output.arr),
125125
image.arr, ct.c_longlong(odim0),
126-
ct.c_longlong(odim1), method.value))
126+
ct.c_longlong(odim1), Enum_value(method)))
127127

128128
return output
129129

@@ -167,7 +167,7 @@ def transform(image, trans_mat, odim0 = 0, odim1 = 0, method=INTERP.NEAREST, is_
167167
safe_call(backend.get().af_transform(ct.pointer(output.arr),
168168
image.arr, trans_mat.arr,
169169
ct.c_longlong(odim0), ct.c_longlong(odim1),
170-
method.value, is_inverse))
170+
Enum_value(method), is_inverse))
171171
return output
172172

173173
def rotate(image, theta, is_crop = True, method = INTERP.NEAREST):
@@ -196,7 +196,7 @@ def rotate(image, theta, is_crop = True, method = INTERP.NEAREST):
196196
"""
197197
output = Array()
198198
safe_call(backend.get().af_rotate(ct.pointer(output.arr), image.arr,
199-
ct.c_double(theta), is_crop, method.value))
199+
ct.c_double(theta), is_crop, Enum_value(method)))
200200
return output
201201

202202
def translate(image, trans0, trans1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST):
@@ -238,7 +238,7 @@ def translate(image, trans0, trans1, odim0 = 0, odim1 = 0, method = INTERP.NEARE
238238
output = Array()
239239
safe_call(backend.get().af_translate(ct.pointer(output.arr),
240240
image.arr, trans0, trans1,
241-
ct.c_longlong(odim0), ct.c_longlong(odim1), method.value))
241+
ct.c_longlong(odim0), ct.c_longlong(odim1), Enum_value(method)))
242242
return output
243243

244244
def scale(image, scale0, scale1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST):
@@ -280,7 +280,7 @@ def scale(image, scale0, scale1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST):
280280
output = Array()
281281
safe_call(backend.get().af_scale(ct.pointer(output.arr),
282282
image.arr, ct.c_double(scale0), ct.c_double(scale1),
283-
ct.c_longlong(odim0), ct.c_longlong(odim1), method.value))
283+
ct.c_longlong(odim0), ct.c_longlong(odim1), Enum_value(method)))
284284
return output
285285

286286
def skew(image, skew0, skew1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST, is_inverse=True):
@@ -326,7 +326,7 @@ def skew(image, skew0, skew1, odim0 = 0, odim1 = 0, method = INTERP.NEAREST, is_
326326
safe_call(backend.get().af_skew(ct.pointer(output.arr),
327327
image.arr, ct.c_double(skew0), ct.c_double(skew1),
328328
ct.c_longlong(odim0), ct.c_longlong(odim1),
329-
method.value, is_inverse))
329+
Enum_value(method), is_inverse))
330330

331331
return output
332332

@@ -609,7 +609,7 @@ def medfilt(image, w0 = 3, w1 = 3, edge_pad = PAD.ZERO):
609609
output = Array()
610610
safe_call(backend.get().af_medfilt(ct.pointer(output.arr),
611611
image.arr, ct.c_longlong(w0),
612-
ct.c_longlong(w1), edge_pad.value))
612+
ct.c_longlong(w1), Enum_value(edge_pad)))
613613
return output
614614

615615
def minfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
@@ -641,7 +641,7 @@ def minfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
641641
output = Array()
642642
safe_call(backend.get().af_minfilt(ct.pointer(output.arr),
643643
image.arr, ct.c_longlong(w_len),
644-
ct.c_longlong(w_wid), edge_pad.value))
644+
ct.c_longlong(w_wid), Enum_value(edge_pad)))
645645
return output
646646

647647
def maxfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
@@ -673,7 +673,7 @@ def maxfilt(image, w_len = 3, w_wid = 3, edge_pad = PAD.ZERO):
673673
output = Array()
674674
safe_call(backend.get().af_maxfilt(ct.pointer(output.arr),
675675
image.arr, ct.c_longlong(w_len),
676-
ct.c_longlong(w_wid), edge_pad.value))
676+
ct.c_longlong(w_wid), Enum_value(edge_pad)))
677677
return output
678678

679679
def regions(image, conn = CONNECTIVITY.FOUR, out_type = Dtype.f32):
@@ -700,7 +700,7 @@ def regions(image, conn = CONNECTIVITY.FOUR, out_type = Dtype.f32):
700700
"""
701701
output = Array()
702702
safe_call(backend.get().af_regions(ct.pointer(output.arr), image.arr,
703-
conn.value, out_type.value))
703+
Enum_value(conn), Enum_value(out_type)))
704704
return output
705705

706706
def sobel_derivatives(image, w_len=3):
@@ -891,5 +891,5 @@ def color_space(image, to_type, from_type):
891891
"""
892892
output = Array()
893893
safe_call(backend.get().af_color_space(ct.pointer(output.arr), image.arr,
894-
to_type.value, from_type.value))
894+
Enum_value(to_type), Enum_value(from_type)))
895895
return output

arrayfire/lapack.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def solve(A, B, options=MATPROP.NONE):
202202
203203
"""
204204
X = Array()
205-
safe_call(backend.get().af_solve(ct.pointer(X.arr), A.arr, B.arr, options.value))
205+
safe_call(backend.get().af_solve(ct.pointer(X.arr), A.arr, B.arr, Enum_value(options)))
206206
return X
207207

208208
def solve_lu(A, P, B, options=MATPROP.NONE):
@@ -230,7 +230,7 @@ def solve_lu(A, P, B, options=MATPROP.NONE):
230230
231231
"""
232232
X = Array()
233-
safe_call(backend.get().af_solve_lu(ct.pointer(X.arr), A.arr, P.arr, B.arr, options.value))
233+
safe_call(backend.get().af_solve_lu(ct.pointer(X.arr), A.arr, P.arr, B.arr, Enum_value(options)))
234234
return X
235235

236236
def inverse(A, options=MATPROP.NONE):
@@ -260,7 +260,7 @@ def inverse(A, options=MATPROP.NONE):
260260
261261
"""
262262
AI = Array()
263-
safe_call(backend.get().af_inverse(ct.pointer(AI.arr), A.arr, options.value))
263+
safe_call(backend.get().af_inverse(ct.pointer(AI.arr), A.arr, Enum_value(options)))
264264
return AI
265265

266266
def rank(A, tol=1E-5):
@@ -336,6 +336,6 @@ def norm(A, norm_type=NORM.EUCLID, p=1.0, q=1.0):
336336
337337
"""
338338
res = ct.c_double(0)
339-
safe_call(backend.get().af_norm(ct.pointer(res), A.arr, norm_type.value,
339+
safe_call(backend.get().af_norm(ct.pointer(res), A.arr, Enum_value(norm_type),
340340
ct.c_double(p), ct.c_double(q)))
341341
return res.value

arrayfire/library.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,19 @@
1313

1414
import platform
1515
import ctypes as ct
16-
from enum import Enum
16+
17+
try:
18+
from enum import Enum
19+
20+
def Enum_value(val):
21+
return val.value
22+
23+
except:
24+
class Enum(object):
25+
pass
26+
27+
def Enum_value(val):
28+
return val
1729

1830
class _clibrary(object):
1931

0 commit comments

Comments
 (0)