Skip to content

Commit e66f299

Browse files
committed
Adding documentation for bcast.py
- Renamed broadcast.py to bcast.py - Renamed variable bcast to bacst_var
1 parent cef85ee commit e66f299

File tree

6 files changed

+109
-48
lines changed

6 files changed

+109
-48
lines changed

arrayfire/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from .features import *
5151
from .vision import *
5252
from .graphics import *
53-
from .broadcast import *
53+
from .bcast import *
5454
from .index import *
5555

5656
# do not export default modules as part of arrayfire
@@ -60,6 +60,6 @@
6060
del os
6161

6262
#do not export internal functions
63-
del bcast
63+
del bcast_var
6464
del is_number
6565
del safe_call

arrayfire/arith.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from .library import *
1515
from .array import *
16-
from .broadcast import *
16+
from .bcast import *
1717

1818
def _arith_binary_func(lhs, rhs, c_func):
1919
out = Array()
@@ -25,21 +25,21 @@ def _arith_binary_func(lhs, rhs, c_func):
2525
raise TypeError("Atleast one input needs to be of type arrayfire.array")
2626

2727
elif (is_left_array and is_right_array):
28-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, bcast.get()))
28+
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, bcast_var.get()))
2929

3030
elif (is_number(rhs)):
3131
ldims = dim4_to_tuple(lhs.dims())
3232
rty = implicit_dtype(rhs, lhs.type())
3333
other = Array()
3434
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
35-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
35+
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast_var.get()))
3636

3737
else:
3838
rdims = dim4_to_tuple(rhs.dims())
3939
lty = implicit_dtype(lhs, rhs.type())
4040
other = Array()
4141
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
42-
safe_call(c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get()))
42+
safe_call(c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast_var.get()))
4343

4444
return out
4545

arrayfire/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import inspect
1515
from .library import *
1616
from .util import *
17-
from .broadcast import *
17+
from .bcast import *
1818
from .base import *
1919
from .index import *
2020

@@ -82,7 +82,7 @@ def _binary_func(lhs, rhs, c_func):
8282
elif not isinstance(rhs, Array):
8383
raise TypeError("Invalid parameter to binary function")
8484

85-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
85+
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast_var.get()))
8686

8787
return out
8888

@@ -98,7 +98,7 @@ def _binary_funcr(lhs, rhs, c_func):
9898
elif not isinstance(lhs, Array):
9999
raise TypeError("Invalid parameter to binary function")
100100

101-
c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get())
101+
c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast_var.get())
102102

103103
return out
104104

arrayfire/bcast.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
"""
11+
Function to perform broadcasting operations.
12+
"""
13+
14+
class _bcast(object):
15+
_flag = False
16+
def get(self):
17+
return _bcast._flag
18+
19+
def set(self, flag):
20+
_bcast._flag = flag
21+
22+
def toggle(self):
23+
_bcast._flag ^= True
24+
25+
bcast_var = _bcast()
26+
27+
def broadcast(func, *args):
28+
"""
29+
Function to perform broadcast operations.
30+
31+
This function can be used directly or as an annotation in the following manner.
32+
33+
Example
34+
-------
35+
36+
Using broadcast as an annotation
37+
38+
>>> import arrayfire as af
39+
>>> @af.broadcast
40+
... def add(a, b):
41+
... return a + b
42+
...
43+
>>> a = af.randu(2,3)
44+
>>> b = af.randu(2,1) # b is a different size
45+
>>> # Trying to add arrays of different sizes raises an exceptions
46+
>>> c = add(a, b) # This call does not raise an exception because of the annotation
47+
>>> af.display(a)
48+
[2 3 1 1]
49+
0.4107 0.9518 0.4198
50+
0.8224 0.1794 0.0081
51+
52+
>>> af.display(b)
53+
[2 1 1 1]
54+
0.7269
55+
0.7104
56+
57+
>>> af.display(c)
58+
[2 3 1 1]
59+
1.1377 1.6787 1.1467
60+
1.5328 0.8898 0.7185
61+
62+
Using broadcast as function
63+
64+
>>> import arrayfire as af
65+
>>> add = lambda a,b: a + b
66+
>>> a = af.randu(2,3)
67+
>>> b = af.randu(2,1) # b is a different size
68+
>>> # Trying to add arrays of different sizes raises an exceptions
69+
>>> c = af.broadcast(add, a, b) # This call does not raise an exception
70+
>>> af.display(a)
71+
[2 3 1 1]
72+
0.4107 0.9518 0.4198
73+
0.8224 0.1794 0.0081
74+
75+
>>> af.display(b)
76+
[2 1 1 1]
77+
0.7269
78+
0.7104
79+
80+
>>> af.display(c)
81+
[2 3 1 1]
82+
1.1377 1.6787 1.1467
83+
1.5328 0.8898 0.7185
84+
85+
"""
86+
87+
def wrapper(*func_args):
88+
bcast_var.toggle()
89+
res = func(*func_args)
90+
bcast_var.toggle()
91+
return res
92+
93+
if len(args) == 0:
94+
return wrapper
95+
else:
96+
return wrapper(*args)

arrayfire/broadcast.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

arrayfire/index.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .library import *
1010
from .util import *
1111
from .base import *
12-
from .broadcast import *
12+
from .bcast import *
1313
import math
1414

1515
class Seq(ct.Structure):
@@ -52,11 +52,11 @@ def __iter__(self):
5252
return self
5353

5454
def next(self):
55-
if bcast.get() is True:
56-
bcast.toggle()
55+
if bcast_var.get() is True:
56+
bcast_var.toggle()
5757
raise StopIteration
5858
else:
59-
bcast.toggle()
59+
bcast_var.toggle()
6060
return self
6161

6262
def __next__(self):

0 commit comments

Comments
 (0)