Skip to content

Commit 2118868

Browse files
authored
fix unsqueeze in dygraph (#27107) (#27151)
* fix unsqueeze in dygraph * add test * add test
1 parent 7a63960 commit 2118868

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6306,6 +6306,15 @@ def unsqueeze(input, axes, name=None):
63066306

63076307
"""
63086308
if in_dygraph_mode():
6309+
if isinstance(axes, int):
6310+
axes = [axes]
6311+
elif isinstance(axes, Variable):
6312+
axes = [axes.numpy().item(0)]
6313+
elif isinstance(axes, (list, tuple)):
6314+
axes = [
6315+
item.numpy().item(0) if isinstance(item, Variable) else item
6316+
for item in axes
6317+
]
63096318
out, _ = core.ops.unsqueeze2(input, 'axes', axes)
63106319
return out
63116320

python/paddle/fluid/tests/unittests/test_unsqueeze_op.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,29 +134,60 @@ def test_out(self):
134134
result1, = exe.run(feed={"data1": input,
135135
"data2": input2},
136136
fetch_list=[result_squeeze])
137-
self.assertTrue(np.allclose(input1, result1))
137+
self.assertTrue(np.array_equal(input1, result1))
138+
self.assertEqual(input1.shape, result1.shape)
138139

139140

140141
class API_TestDyUnsqueeze(unittest.TestCase):
141142
def test_out(self):
142143
with fluid.dygraph.guard():
143144
input_1 = np.random.random([5, 1, 10]).astype("int32")
144-
input1 = np.squeeze(input_1, axis=1)
145+
input1 = np.expand_dims(input_1, axis=1)
145146
input = fluid.dygraph.to_variable(input_1)
146147
output = paddle.unsqueeze(input, axis=[1])
147148
out_np = output.numpy()
148-
self.assertTrue(np.allclose(input1, out_np))
149+
self.assertTrue(np.array_equal(input1, out_np))
150+
self.assertEqual(input1.shape, out_np.shape)
149151

150152

151153
class API_TestDyUnsqueeze2(unittest.TestCase):
152154
def test_out(self):
153155
with fluid.dygraph.guard():
154-
input_1 = np.random.random([5, 1, 10]).astype("int32")
155-
input1 = np.squeeze(input_1, axis=1)
156-
input = fluid.dygraph.to_variable(input_1)
156+
input1 = np.random.random([5, 10]).astype("int32")
157+
out1 = np.expand_dims(input1, axis=1)
158+
input = fluid.dygraph.to_variable(input1)
157159
output = paddle.unsqueeze(input, axis=1)
158160
out_np = output.numpy()
159-
self.assertTrue(np.allclose(input1, out_np))
161+
self.assertTrue(np.array_equal(out1, out_np))
162+
self.assertEqual(out1.shape, out_np.shape)
163+
164+
165+
class API_TestDyUnsqueezeAxisTensor(unittest.TestCase):
166+
def test_out(self):
167+
with fluid.dygraph.guard():
168+
input1 = np.random.random([5, 10]).astype("int32")
169+
out1 = np.expand_dims(input1, axis=1)
170+
input = fluid.dygraph.to_variable(input1)
171+
output = paddle.unsqueeze(input, axis=paddle.to_tensor([1]))
172+
out_np = output.numpy()
173+
self.assertTrue(np.array_equal(out1, out_np))
174+
self.assertEqual(out1.shape, out_np.shape)
175+
176+
177+
class API_TestDyUnsqueezeAxisTensorList(unittest.TestCase):
178+
def test_out(self):
179+
with fluid.dygraph.guard():
180+
input1 = np.random.random([5, 10]).astype("int32")
181+
# Actually, expand_dims supports tuple since version 1.18.0
182+
out1 = np.expand_dims(input1, axis=1)
183+
out1 = np.expand_dims(out1, axis=2)
184+
input = fluid.dygraph.to_variable(input1)
185+
output = paddle.unsqueeze(
186+
fluid.dygraph.to_variable(input1),
187+
axis=[paddle.to_tensor([1]), paddle.to_tensor([2])])
188+
out_np = output.numpy()
189+
self.assertTrue(np.array_equal(out1, out_np))
190+
self.assertEqual(out1.shape, out_np.shape)
160191

161192

162193
if __name__ == "__main__":

python/paddle/tensor/manipulation.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,6 @@ def unsqueeze(x, axis, name=None):
746746
print(out3.shape) # [1, 1, 1, 5, 10]
747747
748748
"""
749-
if isinstance(axis, int):
750-
axis = [axis]
751749

752750
return layers.unsqueeze(x, axis, name)
753751

0 commit comments

Comments
 (0)