Skip to content

Commit a5394fe

Browse files
committed
Adding runtime selection of arrayfire backends using af.backend.set(name)
- Backend is locked after array creation - Backend can be unlocked in an unsafe manner by passing unsafe=True flag
1 parent a33701c commit a5394fe

File tree

16 files changed

+332
-306
lines changed

16 files changed

+332
-306
lines changed

arrayfire/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
del arith_unary_func
5050
del arith_binary_func
5151
del brange
52-
del load_backend
5352
del dim4_tuple
5453
del is_number
5554
del to_str

arrayfire/algorithm.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,57 +25,57 @@ def reduce_all(a, c_func):
2525

2626
def sum(a, dim=None):
2727
if dim is not None:
28-
return parallel_dim(a, dim, clib.af_sum)
28+
return parallel_dim(a, dim, backend.get().af_sum)
2929
else:
30-
return reduce_all(a, clib.af_sum_all)
30+
return reduce_all(a, backend.get().af_sum_all)
3131

3232
def product(a, dim=None):
3333
if dim is not None:
34-
return parallel_dim(a, dim, clib.af_product)
34+
return parallel_dim(a, dim, backend.get().af_product)
3535
else:
36-
return reduce_all(a, clib.af_product_all)
36+
return reduce_all(a, backend.get().af_product_all)
3737

3838
def min(a, dim=None):
3939
if dim is not None:
40-
return parallel_dim(a, dim, clib.af_min)
40+
return parallel_dim(a, dim, backend.get().af_min)
4141
else:
42-
return reduce_all(a, clib.af_min_all)
42+
return reduce_all(a, backend.get().af_min_all)
4343

4444
def max(a, dim=None):
4545
if dim is not None:
46-
return parallel_dim(a, dim, clib.af_max)
46+
return parallel_dim(a, dim, backend.get().af_max)
4747
else:
48-
return reduce_all(a, clib.af_max_all)
48+
return reduce_all(a, backend.get().af_max_all)
4949

5050
def all_true(a, dim=None):
5151
if dim is not None:
52-
return parallel_dim(a, dim, clib.af_all_true)
52+
return parallel_dim(a, dim, backend.get().af_all_true)
5353
else:
54-
return reduce_all(a, clib.af_all_true_all)
54+
return reduce_all(a, backend.get().af_all_true_all)
5555

5656
def any_true(a, dim=None):
5757
if dim is not None:
58-
return parallel_dim(a, dim, clib.af_any_true)
58+
return parallel_dim(a, dim, backend.get().af_any_true)
5959
else:
60-
return reduce_all(a, clib.af_any_true_all)
60+
return reduce_all(a, backend.get().af_any_true_all)
6161

6262
def count(a, dim=None):
6363
if dim is not None:
64-
return parallel_dim(a, dim, clib.af_count)
64+
return parallel_dim(a, dim, backend.get().af_count)
6565
else:
66-
return reduce_all(a, clib.af_count_all)
66+
return reduce_all(a, backend.get().af_count_all)
6767

6868
def imin(a, dim=None):
6969
if dim is not None:
7070
out = Array()
7171
idx = Array()
72-
safe_call(clib.af_imin(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr, ct.c_int(dim)))
72+
safe_call(backend.get().af_imin(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr, ct.c_int(dim)))
7373
return out,idx
7474
else:
7575
real = ct.c_double(0)
7676
imag = ct.c_double(0)
7777
idx = ct.c_uint(0)
78-
safe_call(clib.af_imin_all(ct.pointer(real), ct.pointer(imag), ct.pointer(idx), a.arr))
78+
safe_call(backend.get().af_imin_all(ct.pointer(real), ct.pointer(imag), ct.pointer(idx), a.arr))
7979
real = real.value
8080
imag = imag.value
8181
val = real if imag == 0 else real + imag * 1j
@@ -85,63 +85,63 @@ def imax(a, dim=None):
8585
if dim is not None:
8686
out = Array()
8787
idx = Array()
88-
safe_call(clib.af_imax(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr, ct.c_int(dim)))
88+
safe_call(backend.get().af_imax(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr, ct.c_int(dim)))
8989
return out,idx
9090
else:
9191
real = ct.c_double(0)
9292
imag = ct.c_double(0)
9393
idx = ct.c_uint(0)
94-
safe_call(clib.af_imax_all(ct.pointer(real), ct.pointer(imag), ct.pointer(idx), a.arr))
94+
safe_call(backend.get().af_imax_all(ct.pointer(real), ct.pointer(imag), ct.pointer(idx), a.arr))
9595
real = real.value
9696
imag = imag.value
9797
val = real if imag == 0 else real + imag * 1j
9898
return val,idx.value
9999

100100

101101
def accum(a, dim=0):
102-
return parallel_dim(a, dim, clib.af_accum)
102+
return parallel_dim(a, dim, backend.get().af_accum)
103103

104104
def where(a):
105105
out = Array()
106-
safe_call(clib.af_where(ct.pointer(out.arr), a.arr))
106+
safe_call(backend.get().af_where(ct.pointer(out.arr), a.arr))
107107
return out
108108

109109
def diff1(a, dim=0):
110-
return parallel_dim(a, dim, clib.af_diff1)
110+
return parallel_dim(a, dim, backend.get().af_diff1)
111111

112112
def diff2(a, dim=0):
113-
return parallel_dim(a, dim, clib.af_diff2)
113+
return parallel_dim(a, dim, backend.get().af_diff2)
114114

115115
def sort(a, dim=0, is_ascending=True):
116116
out = Array()
117-
safe_call(clib.af_sort(ct.pointer(out.arr), a.arr, ct.c_uint(dim), ct.c_bool(is_ascending)))
117+
safe_call(backend.get().af_sort(ct.pointer(out.arr), a.arr, ct.c_uint(dim), ct.c_bool(is_ascending)))
118118
return out
119119

120120
def sort_index(a, dim=0, is_ascending=True):
121121
out = Array()
122122
idx = Array()
123-
safe_call(clib.af_sort_index(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr,
123+
safe_call(backend.get().af_sort_index(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr,
124124
ct.c_uint(dim), ct.c_bool(is_ascending)))
125125
return out,idx
126126

127127
def sort_by_key(iv, ik, dim=0, is_ascending=True):
128128
ov = Array()
129129
ok = Array()
130-
safe_call(clib.af_sort_by_key(ct.pointer(ov.arr), ct.pointer(ok.arr),
130+
safe_call(backend.get().af_sort_by_key(ct.pointer(ov.arr), ct.pointer(ok.arr),
131131
iv.arr, ik.arr, ct.c_uint(dim), ct.c_bool(is_ascending)))
132132
return ov,ok
133133

134134
def set_unique(a, is_sorted=False):
135135
out = Array()
136-
safe_call(clib.af_set_unique(ct.pointer(out.arr), a.arr, ct.c_bool(is_sorted)))
136+
safe_call(backend.get().af_set_unique(ct.pointer(out.arr), a.arr, ct.c_bool(is_sorted)))
137137
return out
138138

139139
def set_union(a, b, is_unique=False):
140140
out = Array()
141-
safe_call(clib.af_set_union(ct.pointer(out.arr), a.arr, b.arr, ct.c_bool(is_unique)))
141+
safe_call(backend.get().af_set_union(ct.pointer(out.arr), a.arr, b.arr, ct.c_bool(is_unique)))
142142
return out
143143

144144
def set_intersect(a, b, is_unique=False):
145145
out = Array()
146-
safe_call(clib.af_set_intersect(ct.pointer(out.arr), a.arr, b.arr, ct.c_bool(is_unique)))
146+
safe_call(backend.get().af_set_intersect(ct.pointer(out.arr), a.arr, b.arr, ct.c_bool(is_unique)))
147147
return out

arrayfire/arith.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -46,149 +46,149 @@ def arith_unary_func(a, c_func):
4646

4747
def cast(a, dtype=f32):
4848
out=Array()
49-
safe_call(clib.af_cast(ct.pointer(out.arr), a.arr, dtype))
49+
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, dtype))
5050
return out
5151

5252
def minof(lhs, rhs):
53-
return arith_binary_func(lhs, rhs, clib.af_minof)
53+
return arith_binary_func(lhs, rhs, backend.get().af_minof)
5454

5555
def maxof(lhs, rhs):
56-
return arith_binary_func(lhs, rhs, clib.af_maxof)
56+
return arith_binary_func(lhs, rhs, backend.get().af_maxof)
5757

5858
def rem(lhs, rhs):
59-
return arith_binary_func(lhs, rhs, clib.af_rem)
59+
return arith_binary_func(lhs, rhs, backend.get().af_rem)
6060

6161
def abs(a):
62-
return arith_unary_func(a, clib.af_abs)
62+
return arith_unary_func(a, backend.get().af_abs)
6363

6464
def arg(a):
65-
return arith_unary_func(a, clib.af_arg)
65+
return arith_unary_func(a, backend.get().af_arg)
6666

6767
def sign(a):
68-
return arith_unary_func(a, clib.af_sign)
68+
return arith_unary_func(a, backend.get().af_sign)
6969

7070
def round(a):
71-
return arith_unary_func(a, clib.af_round)
71+
return arith_unary_func(a, backend.get().af_round)
7272

7373
def trunc(a):
74-
return arith_unary_func(a, clib.af_trunc)
74+
return arith_unary_func(a, backend.get().af_trunc)
7575

7676
def floor(a):
77-
return arith_unary_func(a, clib.af_floor)
77+
return arith_unary_func(a, backend.get().af_floor)
7878

7979
def ceil(a):
80-
return arith_unary_func(a, clib.af_ceil)
80+
return arith_unary_func(a, backend.get().af_ceil)
8181

8282
def hypot(lhs, rhs):
83-
return arith_binary_func(lhs, rhs, clib.af_hypot)
83+
return arith_binary_func(lhs, rhs, backend.get().af_hypot)
8484

8585
def sin(a):
86-
return arith_unary_func(a, clib.af_sin)
86+
return arith_unary_func(a, backend.get().af_sin)
8787

8888
def cos(a):
89-
return arith_unary_func(a, clib.af_cos)
89+
return arith_unary_func(a, backend.get().af_cos)
9090

9191
def tan(a):
92-
return arith_unary_func(a, clib.af_tan)
92+
return arith_unary_func(a, backend.get().af_tan)
9393

9494
def asin(a):
95-
return arith_unary_func(a, clib.af_asin)
95+
return arith_unary_func(a, backend.get().af_asin)
9696

9797
def acos(a):
98-
return arith_unary_func(a, clib.af_acos)
98+
return arith_unary_func(a, backend.get().af_acos)
9999

100100
def atan(a):
101-
return arith_unary_func(a, clib.af_atan)
101+
return arith_unary_func(a, backend.get().af_atan)
102102

103103
def atan2(lhs, rhs):
104-
return arith_binary_func(lhs, rhs, clib.af_atan2)
104+
return arith_binary_func(lhs, rhs, backend.get().af_atan2)
105105

106106
def cplx(lhs, rhs=None):
107107
if rhs is None:
108-
return arith_unary_func(lhs, clib.af_cplx)
108+
return arith_unary_func(lhs, backend.get().af_cplx)
109109
else:
110-
return arith_binary_func(lhs, rhs, clib.af_cplx2)
110+
return arith_binary_func(lhs, rhs, backend.get().af_cplx2)
111111

112112
def real(lhs):
113-
return arith_unary_func(lhs, clib.af_real)
113+
return arith_unary_func(lhs, backend.get().af_real)
114114

115115
def imag(lhs):
116-
return arith_unary_func(lhs, clib.af_imag)
116+
return arith_unary_func(lhs, backend.get().af_imag)
117117

118118
def conjg(lhs):
119-
return arith_unary_func(lhs, clib.af_conjg)
119+
return arith_unary_func(lhs, backend.get().af_conjg)
120120

121121
def sinh(a):
122-
return arith_unary_func(a, clib.af_sinh)
122+
return arith_unary_func(a, backend.get().af_sinh)
123123

124124
def cosh(a):
125-
return arith_unary_func(a, clib.af_cosh)
125+
return arith_unary_func(a, backend.get().af_cosh)
126126

127127
def tanh(a):
128-
return arith_unary_func(a, clib.af_tanh)
128+
return arith_unary_func(a, backend.get().af_tanh)
129129

130130
def asinh(a):
131-
return arith_unary_func(a, clib.af_asinh)
131+
return arith_unary_func(a, backend.get().af_asinh)
132132

133133
def acosh(a):
134-
return arith_unary_func(a, clib.af_acosh)
134+
return arith_unary_func(a, backend.get().af_acosh)
135135

136136
def atanh(a):
137-
return arith_unary_func(a, clib.af_atanh)
137+
return arith_unary_func(a, backend.get().af_atanh)
138138

139139
def root(lhs, rhs):
140-
return arith_binary_func(lhs, rhs, clib.af_root)
140+
return arith_binary_func(lhs, rhs, backend.get().af_root)
141141

142142
def pow(lhs, rhs):
143-
return arith_binary_func(lhs, rhs, clib.af_pow)
143+
return arith_binary_func(lhs, rhs, backend.get().af_pow)
144144

145145
def pow2(a):
146-
return arith_unary_func(a, clib.af_pow2)
146+
return arith_unary_func(a, backend.get().af_pow2)
147147

148148
def exp(a):
149-
return arith_unary_func(a, clib.af_exp)
149+
return arith_unary_func(a, backend.get().af_exp)
150150

151151
def expm1(a):
152-
return arith_unary_func(a, clib.af_expm1)
152+
return arith_unary_func(a, backend.get().af_expm1)
153153

154154
def erf(a):
155-
return arith_unary_func(a, clib.af_erf)
155+
return arith_unary_func(a, backend.get().af_erf)
156156

157157
def erfc(a):
158-
return arith_unary_func(a, clib.af_erfc)
158+
return arith_unary_func(a, backend.get().af_erfc)
159159

160160
def log(a):
161-
return arith_unary_func(a, clib.af_log)
161+
return arith_unary_func(a, backend.get().af_log)
162162

163163
def log1p(a):
164-
return arith_unary_func(a, clib.af_log1p)
164+
return arith_unary_func(a, backend.get().af_log1p)
165165

166166
def log10(a):
167-
return arith_unary_func(a, clib.af_log10)
167+
return arith_unary_func(a, backend.get().af_log10)
168168

169169
def log2(a):
170-
return arith_unary_func(a, clib.af_log2)
170+
return arith_unary_func(a, backend.get().af_log2)
171171

172172
def sqrt(a):
173-
return arith_unary_func(a, clib.af_sqrt)
173+
return arith_unary_func(a, backend.get().af_sqrt)
174174

175175
def cbrt(a):
176-
return arith_unary_func(a, clib.af_cbrt)
176+
return arith_unary_func(a, backend.get().af_cbrt)
177177

178178
def factorial(a):
179-
return arith_unary_func(a, clib.af_factorial)
179+
return arith_unary_func(a, backend.get().af_factorial)
180180

181181
def tgamma(a):
182-
return arith_unary_func(a, clib.af_tgamma)
182+
return arith_unary_func(a, backend.get().af_tgamma)
183183

184184
def lgamma(a):
185-
return arith_unary_func(a, clib.af_lgamma)
185+
return arith_unary_func(a, backend.get().af_lgamma)
186186

187187
def iszero(a):
188-
return arith_unary_func(a, clib.af_iszero)
188+
return arith_unary_func(a, backend.get().af_iszero)
189189

190190
def isinf(a):
191-
return arith_unary_func(a, clib.af_isinf)
191+
return arith_unary_func(a, backend.get().af_isinf)
192192

193193
def isnan(a):
194-
return arith_unary_func(a, clib.af_isnan)
194+
return arith_unary_func(a, backend.get().af_isnan)

0 commit comments

Comments
 (0)