Skip to content

Commit 6b543d2

Browse files
committed
Making sure indexing operation is not dropping dimensions
- Added relevant test
1 parent 5581632 commit 6b543d2

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

arrayfire/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def __getitem__(self, key):
426426
try:
427427
out = Array()
428428
n_dims = self.numdims()
429-
inds = get_indices(key, n_dims)
429+
inds = get_indices(key)
430430

431431
safe_call(clib.af_index_gen(ct.pointer(out.arr),
432432
self.arr, ct.c_longlong(n_dims), ct.pointer(inds)))
@@ -446,7 +446,7 @@ def __setitem__(self, key, val):
446446
other_arr = val.arr
447447

448448
out_arr = ct.c_void_p(0)
449-
inds = get_indices(key, n_dims)
449+
inds = get_indices(key)
450450

451451
safe_call(clib.af_assign_gen(ct.pointer(out_arr),
452452
self.arr, ct.c_longlong(n_dims), ct.pointer(inds),

arrayfire/index.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,11 @@ def __init__ (self, idx):
104104
else:
105105
self.idx.seq = Seq(idx)
106106

107-
def get_indices(key, n_dims):
107+
def get_indices(key):
108108

109-
index_vec = Index * n_dims
110-
inds = index_vec()
111-
112-
for n in range(n_dims):
113-
inds[n] = Index(slice(None))
109+
index_vec = Index * 4
110+
S = Index(slice(None))
111+
inds = index_vec(S, S, S, S)
114112

115113
if isinstance(key, tuple):
116114
n_idx = len(key)
@@ -143,9 +141,6 @@ def get_assign_dims(key, idims):
143141
elif isinstance(key, tuple):
144142
n_inds = len(key)
145143

146-
if (n_inds > len(idims)):
147-
raise IndexError("Number of indices greater than array dimensions")
148-
149144
for n in range(n_inds):
150145
if (is_number(key[n])):
151146
dims[n] = 1

tests/simple_index.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,11 @@
5858
for ii in ParallelRange(2,5):
5959
b[ii] = 2
6060
af.display(b)
61+
62+
a = af.randu(3,2)
63+
rows = af.constant(0, 1, dtype=af.s32)
64+
b = a[:,rows]
65+
af.display(b)
66+
for r in rows:
67+
af.display(r)
68+
af.display(b[:,r])

0 commit comments

Comments
 (0)