Skip to content

Commit 69da8bc

Browse files
committed
Moving functions to array.py to fix bugs and cleanup
1 parent 3df2188 commit 69da8bc

File tree

3 files changed

+120
-116
lines changed

3 files changed

+120
-116
lines changed

arrayfire/array.py

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,24 @@ def _ctype_to_lists(ctype_arr, dim, shape, offset=0):
113113
offset += shape[0]
114114
return res
115115

116+
def _slice_to_length(key, dim):
117+
tkey = [key.start, key.stop, key.step]
118+
119+
if tkey[0] is None:
120+
tkey[0] = 0
121+
elif tkey[0] < 0:
122+
tkey[0] = dim - tkey[0]
123+
124+
if tkey[1] is None:
125+
tkey[1] = dim
126+
elif tkey[1] < 0:
127+
tkey[1] = dim - tkey[1]
128+
129+
if tkey[2] is None:
130+
tkey[2] = 1
131+
132+
return int(((tkey[1] - tkey[0] - 1) / tkey[2]) + 1)
133+
116134
def _get_info(dims, buf_len):
117135
elements = 1
118136
numdims = len(dims)
@@ -132,6 +150,102 @@ def _get_info(dims, buf_len):
132150
return numdims, idims
133151

134152

153+
def _get_indices(key):
154+
155+
index_vec = Index * 4
156+
S = Index(slice(None))
157+
inds = index_vec(S, S, S, S)
158+
159+
if isinstance(key, tuple):
160+
n_idx = len(key)
161+
for n in range(n_idx):
162+
inds[n] = Index(key[n])
163+
else:
164+
inds[0] = Index(key)
165+
166+
return inds
167+
168+
def _get_assign_dims(key, idims):
169+
170+
dims = [1]*4
171+
172+
for n in range(len(idims)):
173+
dims[n] = idims[n]
174+
175+
if is_number(key):
176+
dims[0] = 1
177+
return dims
178+
elif isinstance(key, slice):
179+
dims[0] = _slice_to_length(key, idims[0])
180+
return dims
181+
elif isinstance(key, ParallelRange):
182+
dims[0] = _slice_to_length(key.S, idims[0])
183+
return dims
184+
elif isinstance(key, BaseArray):
185+
dims[0] = key.elements()
186+
return dims
187+
elif isinstance(key, tuple):
188+
n_inds = len(key)
189+
190+
for n in range(n_inds):
191+
if (is_number(key[n])):
192+
dims[n] = 1
193+
elif (isinstance(key[n], BaseArray)):
194+
dims[n] = key[n].elements()
195+
elif (isinstance(key[n], slice)):
196+
dims[n] = _slice_to_length(key[n], idims[n])
197+
elif (isinstance(key[n], ParallelRange)):
198+
dims[n] = _slice_to_length(key[n].S, idims[n])
199+
else:
200+
raise IndexError("Invalid type while assigning to arrayfire.array")
201+
202+
return dims
203+
else:
204+
raise IndexError("Invalid type while assigning to arrayfire.array")
205+
206+
207+
def transpose(a, conj=False):
208+
"""
209+
Perform the transpose on an input.
210+
211+
Parameters
212+
-----------
213+
a : af.Array
214+
Multi dimensional arrayfire array.
215+
216+
conj : optional: bool. default: False.
217+
Flag to specify if a complex conjugate needs to applied for complex inputs.
218+
219+
Returns
220+
--------
221+
out : af.Array
222+
Containing the tranpose of `a` for all batches.
223+
224+
"""
225+
out = Array()
226+
safe_call(backend.get().af_transpose(ct.pointer(out.arr), a.arr, conj))
227+
return out
228+
229+
def transpose_inplace(a, conj=False):
230+
"""
231+
Perform inplace transpose on an input.
232+
233+
Parameters
234+
-----------
235+
a : af.Array
236+
- Multi dimensional arrayfire array.
237+
- Contains transposed values on exit.
238+
239+
conj : optional: bool. default: False.
240+
Flag to specify if a complex conjugate needs to applied for complex inputs.
241+
242+
Note
243+
-------
244+
Input `a` needs to be a square matrix or a batch of square matrices.
245+
246+
"""
247+
safe_call(backend.get().af_transpose_inplace(a.arr, conj))
248+
135249
class Array(BaseArray):
136250

137251
"""
@@ -757,7 +871,7 @@ def __getitem__(self, key):
757871
try:
758872
out = Array()
759873
n_dims = self.numdims()
760-
inds = get_indices(key)
874+
inds = _get_indices(key)
761875

762876
safe_call(backend.get().af_index_gen(ct.pointer(out.arr),
763877
self.arr, ct.c_longlong(n_dims), ct.pointer(inds)))
@@ -778,13 +892,13 @@ def __setitem__(self, key, val):
778892
n_dims = self.numdims()
779893

780894
if (is_number(val)):
781-
tdims = get_assign_dims(key, self.dims())
895+
tdims = _get_assign_dims(key, self.dims())
782896
other_arr = constant_array(val, tdims[0], tdims[1], tdims[2], tdims[3], self.type())
783897
else:
784898
other_arr = val.arr
785899

786900
out_arr = ct.c_void_p(0)
787-
inds = get_indices(key)
901+
inds = _get_indices(key)
788902

789903
safe_call(backend.get().af_assign_gen(ct.pointer(out_arr),
790904
self.arr, ct.c_longlong(n_dims), ct.pointer(inds),

arrayfire/data.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,48 +16,6 @@
1616
from .array import *
1717
from .util import *
1818

19-
def transpose(a, conj=False):
20-
"""
21-
Perform the transpose on an input.
22-
23-
Parameters
24-
-----------
25-
a : af.Array
26-
Multi dimensional arrayfire array.
27-
28-
conj : optional: bool. default: False.
29-
Flag to specify if a complex conjugate needs to applied for complex inputs.
30-
31-
Returns
32-
--------
33-
out : af.Array
34-
Containing the tranpose of `a` for all batches.
35-
36-
"""
37-
out = Array()
38-
safe_call(backend.get().af_transpose(ct.pointer(out.arr), a.arr, conj))
39-
return out
40-
41-
def transpose_inplace(a, conj=False):
42-
"""
43-
Perform inplace transpose on an input.
44-
45-
Parameters
46-
-----------
47-
a : af.Array
48-
- Multi dimensional arrayfire array.
49-
- Contains transposed values on exit.
50-
51-
conj : optional: bool. default: False.
52-
Flag to specify if a complex conjugate needs to applied for complex inputs.
53-
54-
Note
55-
-------
56-
Input `a` needs to be a square matrix or a batch of square matrices.
57-
58-
"""
59-
safe_call(backend.get().af_transpose_inplace(a.arr, conj))
60-
6119
def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
6220
"""
6321
Create a multi dimensional array whose elements contain the same value.

arrayfire/index.py

Lines changed: 3 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -60,26 +60,11 @@ def next(self):
6060
return self
6161

6262
def __next__(self):
63+
"""
64+
Function called by the iterator in Python 3
65+
"""
6366
return self.next()
6467

65-
def _slice_to_length(key, dim):
66-
tkey = [key.start, key.stop, key.step]
67-
68-
if tkey[0] is None:
69-
tkey[0] = 0
70-
elif tkey[0] < 0:
71-
tkey[0] = dim - tkey[0]
72-
73-
if tkey[1] is None:
74-
tkey[1] = dim
75-
elif tkey[1] < 0:
76-
tkey[1] = dim - tkey[1]
77-
78-
if tkey[2] is None:
79-
tkey[2] = 1
80-
81-
return int(((tkey[1] - tkey[0] - 1) / tkey[2]) + 1)
82-
8368
class _uidx(ct.Union):
8469
_fields_ = [("arr", ct.c_void_p),
8570
("seq", Seq)]
@@ -103,56 +88,3 @@ def __init__ (self, idx):
10388
self.isBatch = True
10489
else:
10590
self.idx.seq = Seq(idx)
106-
107-
def get_indices(key):
108-
109-
index_vec = Index * 4
110-
S = Index(slice(None))
111-
inds = index_vec(S, S, S, S)
112-
113-
if isinstance(key, tuple):
114-
n_idx = len(key)
115-
for n in range(n_idx):
116-
inds[n] = Index(key[n])
117-
else:
118-
inds[0] = Index(key)
119-
120-
return inds
121-
122-
def get_assign_dims(key, idims):
123-
124-
dims = [1]*4
125-
126-
for n in range(len(idims)):
127-
dims[n] = idims[n]
128-
129-
if is_number(key):
130-
dims[0] = 1
131-
return dims
132-
elif isinstance(key, slice):
133-
dims[0] = _slice_to_length(key, idims[0])
134-
return dims
135-
elif isinstance(key, ParallelRange):
136-
dims[0] = _slice_to_length(key.S, idims[0])
137-
return dims
138-
elif isinstance(key, BaseArray):
139-
dims[0] = key.elements()
140-
return dims
141-
elif isinstance(key, tuple):
142-
n_inds = len(key)
143-
144-
for n in range(n_inds):
145-
if (is_number(key[n])):
146-
dims[n] = 1
147-
elif (isinstance(key[n], BaseArray)):
148-
dims[n] = key[n].elements()
149-
elif (isinstance(key[n], slice)):
150-
dims[n] = _slice_to_length(key[n], idims[n])
151-
elif (isinstance(key[n], ParallelRange)):
152-
dims[n] = _slice_to_length(key[n].S, idims[n])
153-
else:
154-
raise IndexError("Invalid type while assigning to arrayfire.array")
155-
156-
return dims
157-
else:
158-
raise IndexError("Invalid type while assigning to arrayfire.array")

0 commit comments

Comments
 (0)