Skip to content

Commit 298438d

Browse files
author
Haifeng Han
committed
fix issues of begin_mask, end_mask, ellipsis_mask in strided_slice
1 parent e306be1 commit 298438d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

python/slice.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ def setup(self, bottom, top):
7474
raise NotImplementedError("new_axis_mask is not implemented")
7575
self.shrink_axis = []
7676
for i, d in enumerate(shape):
77-
if self.begin_mask // 2**i:
77+
if self.begin_mask & 1 << i:
7878
begin[i] = 0
79-
if self.end_mask // 2**i:
79+
if self.end_mask & 1 << i:
8080
end[i] = shape[i]
81-
if self.ellipsis_mask // 2**i:
81+
if self.ellipsis_mask & 1 << i:
8282
begin[i] = 0
8383
end[i] = shape[i]
8484
strides[i] = 1
85-
if self.shrink_axis_mask // 2**i:
85+
if self.shrink_axis_mask & 1 << i:
8686
end[i] = begin[i] + 1
8787
strides[i] = 1
8888
self.shrink_axis.append(i)
@@ -100,7 +100,7 @@ def reshape(self, bottom, top):
100100
if self.strides[i] == 0:
101101
raise Exception("Strides should never be equal to 0!")
102102
else:
103-
num[i] = (abs(self.end[i]-self.begin[i])/abs(self.strides[i]))
103+
num[i] = abs(self.end[i]-self.begin[i])/abs(self.strides[i])
104104
for i in reversed(self.shrink_axis):
105105
num.pop(i)
106106
top[0].reshape(*num)

0 commit comments

Comments
 (0)