Skip to content

Commit 870ee50

Browse files
committed
FEAT: Adding broadcast to arrayfire
1 parent 3b8b928 commit 870ee50

File tree

5 files changed

+42
-5
lines changed

5 files changed

+42
-5
lines changed

arrayfire/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .features import *
2323
from .vision import *
2424
from .graphics import *
25+
from .broadcast import *
2526

2627
# do not export default modules as part of arrayfire
2728
del ct
@@ -34,6 +35,7 @@
3435
del seq
3536
del index
3637
del cell
38+
del bcast
3739

3840
#do not export internal functions
3941
del binary_func

arrayfire/arith.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .library import *
1111
from .array import *
12+
from .broadcast import *
1213

1314
def arith_binary_func(lhs, rhs, c_func):
1415
out = array()
@@ -20,21 +21,21 @@ def arith_binary_func(lhs, rhs, c_func):
2021
TypeError("Atleast one input needs to be of type arrayfire.array")
2122

2223
elif (is_left_array and is_right_array):
23-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, False))
24+
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, bcast.get()))
2425

2526
elif (is_number(rhs)):
2627
ldims = dim4_tuple(lhs.dims())
2728
lty = lhs.type()
2829
other = array()
2930
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
30-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
31+
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
3132

3233
else:
3334
rdims = dim4_tuple(rhs.dims())
3435
rty = rhs.type()
3536
other = array()
3637
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
37-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
38+
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
3839

3940
return out
4041

arrayfire/array.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import inspect
1111
from .library import *
1212
from .util import *
13+
from .broadcast import *
1314

1415
def create_array(buf, numdims, idims, dtype):
1516
out_arr = ct.c_longlong(0)
@@ -63,7 +64,7 @@ def binary_func(lhs, rhs, c_func):
6364
elif not isinstance(rhs, array):
6465
raise TypeError("Invalid parameter to binary function")
6566

66-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
67+
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
6768

6869
return out
6970

@@ -79,7 +80,7 @@ def binary_funcr(lhs, rhs, c_func):
7980
elif not isinstance(lhs, array):
8081
raise TypeError("Invalid parameter to binary function")
8182

82-
c_func(ct.pointer(out.arr), other.arr, rhs.arr, False)
83+
c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get())
8384

8485
return out
8586

arrayfire/broadcast.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
11+
class bcast(object):
12+
_flag = False
13+
def get():
14+
return bcast._flag
15+
16+
def set(flag):
17+
bcast._flag = flag
18+
19+
def toggle():
20+
bcast._flag ^= True
21+
22+
def broadcast(func, *args):
23+
bcast.toggle()
24+
res = func(*args)
25+
bcast.toggle()
26+
return res

tests/simple_arith.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,10 @@
188188
af.display(af.iszero(a))
189189
af.display(af.isinf(a/b))
190190
af.display(af.isnan(a/a))
191+
192+
a = af.randu(5, 1)
193+
b = af.randu(5, 5)
194+
c = af.broadcast(lambda x,y: x+y, a, b)
195+
af.display(a)
196+
af.display(b)
197+
af.display(c)

0 commit comments

Comments
 (0)