Skip to content

Commit 3a0c550

Browse files
authored
Fix dropout static when axis != None (#37223) (#37589)
* fix dropout static when axis != None * update dropout test * add dropout test * fix test * Update test_dropout_op.py * Update test_dropout_op.py * fix testcase * fix testcase * Update test_dropout_op.py * fix testcase * fix testcase * optimize perf * add new test * fix testcase
1 parent 7d9c669 commit 3a0c550

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

python/paddle/fluid/tests/unittests/test_dropout_op.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def setUp(self):
333333

334334
def check_static_result(self, place):
335335
with fluid.program_guard(fluid.Program(), fluid.Program()):
336-
input = fluid.data(name="input", shape=[40, 40], dtype="float32")
336+
input = fluid.data(name="input", shape=[-1, -1], dtype="float32")
337337
res1 = paddle.nn.functional.dropout(x=input, p=0., training=False)
338338
res2 = paddle.nn.functional.dropout(
339339
x=input, p=0., axis=0, training=True, mode='upscale_in_train')
@@ -380,7 +380,10 @@ def check_static_result(self, place):
380380
training=False,
381381
mode='upscale_in_train')
382382

383-
in_np = np.random.random([40, 40]).astype("float32")
383+
res13 = paddle.nn.functional.dropout(
384+
x=input, p=0.7, axis=1, training=True, mode='upscale_in_train')
385+
386+
in_np = np.ones([40, 40]).astype("float32")
384387
res_np = in_np
385388
res_np2 = np.zeros_like(in_np)
386389

@@ -398,6 +401,9 @@ def check_static_result(self, place):
398401
feed={"input": in_np},
399402
fetch_list=[res10])
400403
self.assertTrue(np.allclose(fetches2[0], res_np2))
404+
fetches3 = exe.run(fluid.default_main_program(),
405+
feed={"input": in_np},
406+
fetch_list=[res13])
401407

402408
def test_static(self):
403409
for place in self.places:
@@ -471,6 +477,12 @@ def test_dygraph(self):
471477
axis=(0, 1),
472478
training=False,
473479
mode='upscale_in_train')
480+
res13 = paddle.nn.functional.dropout(
481+
x=input,
482+
p=0.5,
483+
axis=1,
484+
training=True,
485+
mode='upscale_in_train')
474486

475487
res_list = [
476488
res1, res2, res3, res4, res5, res6, res7, res8, res9, res11,

python/paddle/nn/functional/common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,8 @@ def get_attrs(prog, dropout_prob, is_test, seed):
939939

940940
#get mask shape
941941
input_shape = x.shape
942+
if not in_dygraph_mode():
943+
input_shape_tensor = paddle.shape(x)
942944
drop_axes = [axis] if isinstance(axis, int) else list(axis)
943945
if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
944946
raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \
@@ -948,8 +950,12 @@ def get_attrs(prog, dropout_prob, is_test, seed):
948950
"length of axis should not be greater than dimensions of x:{}, but get length of axis: {}".
949951
format(len(input_shape), len(drop_axes)))
950952
mask_shape = [1] * len(input_shape)
951-
for i in drop_axes:
952-
mask_shape[i] = input_shape[i]
953+
if not in_dygraph_mode():
954+
for i in drop_axes:
955+
mask_shape[i] = input_shape_tensor[i]
956+
else:
957+
for i in drop_axes:
958+
mask_shape[i] = input_shape[i]
953959

954960
#get mask
955961
random_tensor = paddle.uniform(

0 commit comments

Comments
 (0)