Skip to content

Commit e014950

Browse files
author
wopeizl
authored
add slice support for dim < 0 (#16494)
* add slice support for dim < 0 test=develop
1 parent 8f7b588 commit e014950

File tree

2 files changed

+39
-38
lines changed

2 files changed

+39
-38
lines changed

python/paddle/fluid/framework.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -789,13 +789,24 @@ def __getitem__(self, item):
789789
if isinstance(item, tuple):
790790
if len(item) > len(self.shape):
791791
raise IndexError("Too many indexes")
792+
fixedSize = True
793+
for i in range(len(self.shape)):
794+
if self.shape[i] == -1:
795+
fixedSize = False
796+
break
797+
792798
newitem = self._reconstructSliceinfo(item) or item
793-
check, info = self._detectContinuesSlice(newitem)
794-
if check:
795-
starts = info[0]
796-
ends = info[1]
797-
axes = [i for i in range(len(starts))]
798-
return self._sliceVar(axes, starts, ends)
799+
if fixedSize:
800+
check, info = self._detectContinuesSlice(newitem)
801+
if check and fixedSize:
802+
starts = info[0]
803+
ends = info[1]
804+
axes = [i for i in range(len(starts))]
805+
return self._sliceVar(axes, starts, ends)
806+
else:
807+
new_var = self
808+
for index, o in enumerate(newitem):
809+
new_var = new_var._sliceAndConcatVar(o, index)
799810
else:
800811
new_var = self
801812
for index, o in enumerate(newitem):

python/paddle/fluid/tests/unittests/test_variable.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_step_scopes(self):
6161
name='step_scopes', type=core.VarDesc.VarType.STEP_SCOPES)
6262
self.assertEqual(core.VarDesc.VarType.STEP_SCOPES, var.type)
6363

64-
def _test_slice(self):
64+
def _test_slice(self, place):
6565
b = default_main_program().current_block()
6666
w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0)
6767

@@ -83,7 +83,6 @@ def _test_slice(self):
8383

8484
self.assertEqual(0, nw.lod_level)
8585

86-
place = fluid.CPUPlace()
8786
main = fluid.Program()
8887
with fluid.program_guard(main):
8988
exe = fluid.Executor(place)
@@ -100,10 +99,23 @@ def _test_slice(self):
10099
var6 = var[1, 1:, 1:]
101100
var7 = var[1, ..., 1:]
102101
var8 = var[1, ...]
102+
var_reshape = fluid.layers.reshape(var, [3, -1, 3])
103+
var9 = var_reshape[1, ..., 2]
104+
var10 = var_reshape[:, :, -1]
105+
106+
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
107+
y = fluid.layers.fc(input=x, size=1, act=None)
108+
var11 = y[:, 0]
109+
feeder = fluid.DataFeeder(place=place, feed_list=[x])
110+
data = []
111+
data.append((np.random.randint(10, size=[13]).astype('float32')))
112+
exe.run(fluid.default_startup_program())
113+
103114
local_out = exe.run(main,
115+
feed=feeder.feed([data]),
104116
fetch_list=[
105117
var, var1, var2, var3, var4, var5, var6,
106-
var7, var8
118+
var7, var8, var9, var10, var11
107119
])
108120

109121
self.assertTrue((np.array(local_out[1]) == np.array(tensor_array[
@@ -122,38 +134,16 @@ def _test_slice(self):
122134
1, ..., 1:])).all())
123135
self.assertTrue((np.array(local_out[8]) == np.array(tensor_array[
124136
1, ...])).all())
137+
self.assertEqual(local_out[9].shape, (1, 3, 1))
138+
self.assertEqual(local_out[10].shape, (3, 3, 1))
139+
self.assertEqual(local_out[11].shape, (1, 1))
125140

126141
def test_slice(self):
127-
self._test_slice()
128-
129-
130-
class TestVariableImperative(unittest.TestCase):
131-
def _test_slice(self):
132-
b = default_main_program().current_block()
133-
w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0)
134-
135-
for i in range(3):
136-
nw = w[i]
137-
self.assertEqual([1, 100, 100], nw.shape)
138-
139-
nw = w[:]
140-
self.assertEqual([784, 100, 100], nw.shape)
141-
142-
nw = w[:, :, :]
143-
self.assertEqual([784, 100, 100], nw.shape)
144-
145-
nw = w[::2, ::2, :]
146-
self.assertEqual([392, 50, 100], nw.shape)
147-
148-
nw = w[::-2, ::-2, :]
149-
self.assertEqual([392, 50, 100], nw.shape)
150-
151-
nw = w[0::-2, 0::-2, :]
152-
self.assertEqual([1, 1, 100], nw.shape)
142+
place = fluid.CPUPlace()
143+
self._test_slice(place)
153144

154-
def test_slice(self):
155-
with fluid.dygraph.guard():
156-
self._test_slice()
145+
if core.is_compiled_with_cuda():
146+
self._test_slice(core.CUDAPlace(0))
157147

158148

159149
if __name__ == '__main__':

0 commit comments

Comments
 (0)