diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index a781fb74f92168..e38f05aeecda08 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1911,6 +1911,7 @@ def ctc_loss(
     blank: int = 0,
     reduction: _ReduceMode = 'mean',
     norm_by_times: bool = False,
+    zero_infinity: bool = False,
 ) -> Tensor:
     """
@@ -1927,6 +1928,7 @@
         blank (int, optional): The blank label index of the Connectionist Temporal Classification (CTC) loss, which is in the half-open interval [0, num_classes + 1). The data type must be int32. Default: 0.
         reduction (str, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then the mean of the quotient is returned; if :attr:`reduction` is ``'sum'``, the sum of the loss is returned; if :attr:`reduction` is ``'none'``, no reduction will be applied. Default: ``'mean'``.
         norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-steps, which is also the sequence's length. There is no need to normalize the gradients if the reduction mode is ``'mean'``. Default: False.
+        zero_infinity (bool, optional): If True, set infinite losses to zero. Infinite losses mainly occur when an input sequence is too short to be aligned to its label. Default: False.

     Returns:
         Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is []. Data type is the same as ``log_probs``.
@@ -2041,8 +2043,17 @@ def warpctc(

     loss_out = warpctc(
         log_probs, labels, blank, norm_by_times, input_lengths, label_lengths
     )
-    loss_out = paddle.squeeze(loss_out, [-1])
+
+    if zero_infinity:
+        inf_mask = paddle.isinf(loss_out)
+        zero_value = paddle.zeros_like(loss_out)
+        loss_out = paddle.where(
+            condition=inf_mask,
+            x=zero_value,
+            y=loss_out,
+        )
+
     assert reduction in ['mean', 'sum', 'none']
     if reduction == 'mean':
         loss_out = paddle.mean(loss_out / label_lengths.astype(loss_out.dtype))
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index b27ef6725d9a49..ef2e22ce60624a 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1352,6 +1352,7 @@ def forward(
         input_lengths: Tensor,
         label_lengths: Tensor,
         norm_by_times: bool = False,
+        zero_infinity: bool = False,
     ) -> Tensor:
         return paddle.nn.functional.ctc_loss(
             log_probs,
@@ -1361,6 +1362,7 @@
             self.blank,
             self.reduction,
             norm_by_times=norm_by_times,
+            zero_infinity=zero_infinity,
         )

diff --git a/test/legacy_test/test_warpctc_op.py b/test/legacy_test/test_warpctc_op.py
index 982ccc21ff2b97..9c6f0dfa5b08a0 100644
--- a/test/legacy_test/test_warpctc_op.py
+++ b/test/legacy_test/test_warpctc_op.py
@@ -754,6 +754,97 @@ def test_functional_api():

         test_functional_api()

+    def test_ctc_loss_zero_infinity(self):
+        max_time = 1
+        batch = 1
+        n_class = 8
+        logits_np = np.random.randn(max_time, batch, n_class).astype("float32")
+        labels_np = np.random.randint(0, n_class - 1, (batch, 3)).astype(
+            "int32"
+        )
+        input_len_np = np.array([1], dtype=np.int64)
+        label_len_np = np.array([3], dtype=np.int64)
+
+        paddle.enable_static()
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+
+        with paddle.static.program_guard(main_program, startup_program):
+            logits = paddle.static.data(
+                name="logits_il",
+                shape=[max_time, batch, n_class],
+                dtype="float32",
+            )
+            labels = paddle.static.data(
+                name="labels_il", shape=[batch, 3], dtype="int32"
+            )
+            input_len = paddle.static.data(
+                name="input_len_il", shape=[batch], dtype="int64"
+            )
+            label_len = paddle.static.data(
+                name="label_len_il", shape=[batch], dtype="int64"
+            )
+
+            loss = paddle.nn.functional.ctc_loss(
+                log_probs=logits,
+                labels=labels,
+                input_lengths=input_len,
+                label_lengths=label_len,
+                reduction="none",
+                zero_infinity=True,
+                blank=n_class - 1,
+            )
+
+        exe = paddle.static.Executor()
+        loss_val = exe.run(
+            main_program,
+            feed={
+                "logits_il": logits_np,
+                "labels_il": labels_np,
+                "input_len_il": input_len_np,
+                "label_len_il": label_len_np,
+            },
+            fetch_list=[loss],
+        )[0]
+
+        # The input (length 1) is too short to align with the length-3
+        # label, so the raw loss is inf; zero_infinity maps it to 0.
+        np.testing.assert_allclose(loss_val, [0.0], atol=1e-6)
+
+        paddle.disable_static()
+
+    def test_ctc_loss_zero_infinity_dygraph(self):
+        max_time = 1
+        batch = 1
+        n_class = 8
+
+        logits_np = np.random.randn(max_time, batch, n_class).astype("float32")
+        labels_np = np.random.randint(0, n_class - 1, (batch, 3)).astype(
+            "int32"
+        )
+        input_len_np = np.array([1], dtype=np.int64)
+        label_len_np = np.array([3], dtype=np.int64)
+
+        paddle.disable_static()
+        logits = paddle.to_tensor(logits_np)
+        labels = paddle.to_tensor(labels_np)
+        input_len = paddle.to_tensor(input_len_np)
+        label_len = paddle.to_tensor(label_len_np)
+
+        loss = paddle.nn.functional.ctc_loss(
+            log_probs=logits,
+            labels=labels,
+            input_lengths=input_len,
+            label_lengths=label_len,
+            reduction="none",
+            zero_infinity=True,
+            blank=n_class - 1,
+        )
+
+        # Same unalignable setup as above; compare with atol because the
+        # expected value is exactly zero (rtol alone is vacuous here).
+        np.testing.assert_allclose(loss.numpy(), [0.0], atol=1e-6)
+

 if __name__ == "__main__":
     unittest.main()
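For reviewers, a minimal dygraph sketch of what the new flag does once this patch is applied. The shapes, label values, and blank index below are illustrative, not taken from the diff:

```python
import numpy as np
import paddle

# A length-1 input cannot be aligned to a length-3 label, so the raw CTC
# loss is +inf; with zero_infinity=True it is masked to 0 instead.
max_time, batch, n_class = 1, 1, 8
log_probs = paddle.nn.functional.log_softmax(
    paddle.to_tensor(np.random.randn(max_time, batch, n_class).astype("float32")),
    axis=-1,
)
labels = paddle.to_tensor(np.array([[1, 2, 3]], dtype="int32"))
input_lengths = paddle.to_tensor(np.array([1], dtype="int64"))
label_lengths = paddle.to_tensor(np.array([3], dtype="int64"))

loss = paddle.nn.functional.ctc_loss(
    log_probs=log_probs,
    labels=labels,
    input_lengths=input_lengths,
    label_lengths=label_lengths,
    blank=n_class - 1,
    reduction="none",
    zero_infinity=True,
)
print(loss)  # [0.] instead of [inf]
```

With `reduction='mean'`, the zeroed entries simply contribute 0 to the batch average instead of turning the whole reduction into inf, which is the usual reason to enable the flag on datasets with occasional unalignable samples.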