Skip to content

Commit dbd453e

Browse files
add paddle nn.functional.dropout1d api (#74444)
1 parent 607dd38 commit dbd453e

File tree

3 files changed

+192
-0
lines changed

3 files changed

+192
-0
lines changed

python/paddle/nn/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
class_center_sample,
6060
cosine_similarity,
6161
dropout,
62+
dropout1d,
6263
dropout2d,
6364
dropout3d,
6465
feature_alpha_dropout,
@@ -216,6 +217,7 @@
216217
'gumbel_softmax',
217218
'sequence_mask',
218219
'dropout',
220+
'dropout1d',
219221
'dropout2d',
220222
'dropout3d',
221223
'alpha_dropout',

python/paddle/nn/functional/common.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import math
18+
import warnings
1819
from typing import TYPE_CHECKING, Literal
1920

2021
import numpy
@@ -1427,6 +1428,74 @@ def get_attrs(prog, dropout_prob, is_test, seed):
14271428
return ret
14281429

14291430

1431+
def dropout1d(
1432+
input: paddle.Tensor,
1433+
p: float = 0.5,
1434+
training: bool = True,
1435+
inplace: bool = False,
1436+
) -> paddle.Tensor:
1437+
"""
1438+
Randomly zero out entire 1D channels (feature maps) during training.
1439+
1440+
Args:
1441+
input: Input tensor of shape [C, L] (2D) or [N, C, L] (3D)
1442+
p: Probability of a channel being zeroed. Default: 0.5
1443+
training: If False, returns input unchanged. Default: True
1444+
inplace: If True, modifies input tensor in-place. Default: False
1445+
WARNING: Currently not implemented (will behave as False).
1446+
TODO: Implement in-place operation in future versions.
1447+
Default: False
1448+
1449+
Returns:
1450+
Tensor with the same shape as input, where entire channels are zeroed with probability p
1451+
1452+
Examples:
1453+
.. code-block:: python
1454+
1455+
>>> import paddle
1456+
1457+
# Case 1: 3D input (batched)
1458+
>>> x = paddle.randn([2, 3, 10]) # [N, C, L]
1459+
>>> y_train = paddle.nn.functional.dropout1d(x, p=0.2) # Training mode
1460+
>>> y_test = paddle.nn.functional.dropout1d(x, p=0.2, training=False) # Test mode
1461+
>>> print("Original first channel:", x[0, 0, :])
1462+
>>> print("Train output (may be zeroed):", y_train[0, 0, :])
1463+
>>> print("Test output (always unchanged):", y_test[0, 0, :])
1464+
1465+
# Case 2: 2D input (single sample)
1466+
>>> x = paddle.randn([3, 8]) # [C, L]
1467+
>>> y = paddle.nn.functional.dropout1d(x, p=0.5)
1468+
>>> print("Input shape:", x.shape)
1469+
>>> print("Output shape:", y.shape)
1470+
>>> print("Zeroed channels count:", paddle.sum(y == 0).item())
1471+
"""
1472+
if p < 0 or p > 1:
1473+
raise ValueError(f"dropout probability must be in [0, 1], got {p}")
1474+
1475+
ndim = input.ndim
1476+
if ndim not in [2, 3]:
1477+
raise RuntimeError(f"dropout1d expects 2D or 3D input, got {ndim}D")
1478+
1479+
if inplace:
1480+
warnings.warn(
1481+
"inplace=True is currently not supported in dropout1d and will be ignored. "
1482+
"This parameter is reserved for future implementation."
1483+
)
1484+
# TODO: Implement actual in-place operation when supported by dropout
1485+
1486+
need_squeeze = ndim == 2
1487+
if need_squeeze:
1488+
input = input.unsqueeze(0) # [C, L] -> [1, C, L]
1489+
1490+
# Apply dropout along channel dimension
1491+
result = dropout(input, p=p, axis=1, training=training)
1492+
1493+
if need_squeeze:
1494+
result = result.squeeze(0) # [1, C, L] -> [C, L]
1495+
1496+
return result
1497+
1498+
14301499
def dropout2d(
14311500
x: Tensor,
14321501
p: float = 0.5,

test/legacy_test/test_dropout_op.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,121 @@ def test_dygraph(self):
919919
)
920920

921921

922+
class TestDropout1DFAPI(unittest.TestCase):
923+
def setUp(self):
924+
np.random.seed(123)
925+
self.places = get_places()
926+
927+
def check_static_result(
928+
self, place, input_name, input_shape, training=False, p=0.0
929+
):
930+
paddle.enable_static()
931+
main_prog = paddle.static.Program()
932+
startup_prog = paddle.static.Program()
933+
with paddle.static.program_guard(main_prog, startup_prog):
934+
input_var = paddle.static.data(
935+
name=input_name, shape=input_shape, dtype="float32"
936+
)
937+
res = paddle.nn.functional.dropout1d(
938+
input=input_var, p=p, training=training
939+
)
940+
in_np = np.random.random(input_shape).astype("float32")
941+
exe = base.Executor(place)
942+
fetches = exe.run(
943+
main_prog,
944+
feed={input_name: in_np},
945+
fetch_list=[res],
946+
)
947+
948+
np.testing.assert_allclose(fetches[0], in_np, rtol=1e-05)
949+
950+
def test_static(self):
951+
for place in self.places:
952+
self.check_static_result(
953+
place=place,
954+
input_name="input_2d",
955+
input_shape=[3, 4],
956+
training=False,
957+
p=0.0,
958+
)
959+
960+
self.check_static_result(
961+
place=place,
962+
input_name="input_3d",
963+
input_shape=[2, 3, 4],
964+
training=False,
965+
p=0.0,
966+
)
967+
968+
self.check_static_result(
969+
place=place,
970+
input_name="input_2d_1",
971+
input_shape=[3, 4],
972+
training=False,
973+
p=1.0,
974+
)
975+
976+
self.check_static_result(
977+
place=place,
978+
input_name="input_3d_1",
979+
input_shape=[2, 3, 4],
980+
training=False,
981+
p=1.0,
982+
)
983+
984+
def test_dygraph(self):
985+
for place in self.places:
986+
with base.dygraph.guard(place):
987+
# Test 2D input
988+
in_np_2d = np.random.random([3, 4]).astype("float32")
989+
input_2d = paddle.to_tensor(in_np_2d)
990+
res1 = paddle.nn.functional.dropout1d(
991+
input=input_2d, p=0.0, training=False
992+
)
993+
np.testing.assert_allclose(res1.numpy(), in_np_2d, rtol=1e-05)
994+
995+
# Test 3D input
996+
in_np_3d = np.random.random([2, 3, 4]).astype("float32")
997+
input_3d = paddle.to_tensor(in_np_3d)
998+
res2 = paddle.nn.functional.dropout1d(
999+
input=input_3d, p=0.0, training=False
1000+
)
1001+
np.testing.assert_allclose(res2.numpy(), in_np_3d, rtol=1e-05)
1002+
1003+
1004+
class TestDropout1DFAPIError(unittest.TestCase):
1005+
def test_errors(self):
1006+
paddle.enable_static()
1007+
main_prog = paddle.static.Program()
1008+
startup_prog = paddle.static.Program()
1009+
with paddle.static.program_guard(main_prog, startup_prog):
1010+
1011+
def test_xdim_1d():
1012+
# dimensions of x should be 2 or 3
1013+
x = paddle.static.data(name='x1', shape=[4], dtype="float32")
1014+
paddle.nn.functional.dropout1d(x)
1015+
1016+
self.assertRaises(RuntimeError, test_xdim_1d)
1017+
1018+
def test_xdim_4d():
1019+
# dimensions of x should be 2 or 3
1020+
x = paddle.static.data(
1021+
name='x2', shape=[2, 3, 4, 5], dtype="float32"
1022+
)
1023+
paddle.nn.functional.dropout1d(x)
1024+
1025+
self.assertRaises(RuntimeError, test_xdim_4d)
1026+
1027+
def test_prob_range():
1028+
# p should be in [0, 1]
1029+
x = paddle.static.data(
1030+
name='x3', shape=[2, 3, 4], dtype="float32"
1031+
)
1032+
paddle.nn.functional.dropout1d(x, p=1.5)
1033+
1034+
self.assertRaises(ValueError, test_prob_range)
1035+
1036+
9221037
class TestDropout2DFAPI(unittest.TestCase):
9231038
def setUp(self):
9241039
np.random.seed(123)
@@ -1404,6 +1519,12 @@ def test_p_tensor(self):
14041519
np.testing.assert_array_equal(static_res, dygraph_res)
14051520

14061521

1522+
class TestDropOut1DWithProbTensor(TestDropOutWithProbTensor):
1523+
def init_info(self):
1524+
self.shape = [2, 3, 4]
1525+
self.api = paddle.nn.functional.dropout1d
1526+
1527+
14071528
class TestDropOut2DWithProbTensor(TestDropOutWithProbTensor):
14081529
def init_info(self):
14091530
self.shape = [2, 3, 10, 10]

0 commit comments

Comments
 (0)