Merged
Commits
36 commits
eb6724f
feat: ctc_loss.zero_infinity
aztice Oct 11, 2025
48f9719
fix: code-style issue.
aztice Oct 11, 2025
e6d76e0
fix: code-style issue.
aztice Oct 11, 2025
e85ceb1
optimize: reduce calculation
aztice Oct 14, 2025
fbffc8b
fix: code-style issue.
aztice Oct 14, 2025
5179ca5
feat: test_warpctc_zero_infinity
aztice Oct 14, 2025
b37acdd
fix: code-style issue.
aztice Oct 14, 2025
948b74e
fix: code-style issue.
aztice Oct 14, 2025
86f5315
fix: zero_infinity for ctc_loss
aztice Oct 14, 2025
417dc59
fix: code-style issue.
aztice Oct 14, 2025
5d06635
fix: code-style issue.
aztice Oct 14, 2025
7d6f730
feat: ctcloss.zero_infinity
aztice Oct 17, 2025
df49adc
fix: docs issue.
aztice Oct 21, 2025
4d1dab9
fix: error
aztice Oct 21, 2025
34c51b0
fix: code-style issue.
aztice Oct 21, 2025
068f8a1
fix: errors
aztice Oct 21, 2025
480eb12
Fix loss output handling for zero_infinity case
aztice Oct 21, 2025
2e52ae9
fix: errors
aztice Oct 25, 2025
a7e9cc2
Replace zeros_like with full_like for loss output
aztice Oct 28, 2025
41279be
Refactor handling of infinite loss values
aztice Oct 28, 2025
d142222
Improve readability of loss_out assignment
aztice Oct 28, 2025
f5a8a75
Fix paddle.where parameters in loss function
aztice Oct 28, 2025
8573e56
Remove test_ctc_loss_zero_infinity function
aztice Oct 30, 2025
a4034d7
Add tests for CTC loss with zero_infinity
aztice Oct 30, 2025
2f0af5a
Format labels_np initialization for readability
aztice Oct 30, 2025
3f132ca
Refactor string literals and improve formatting
aztice Oct 30, 2025
6e0cc73
Improve code formatting in test_warpctc_op.py
aztice Oct 30, 2025
6c7c51a
Refactor static_graph context manager usage
aztice Oct 30, 2025
fe2834a
Use fixed seed value in static_graph context manager
aztice Oct 30, 2025
0ca8eb1
Refactor static graph usage in CTC loss test
aztice Oct 30, 2025
a49b6bc
Refactor CTC loss test to use program guard
aztice Oct 30, 2025
7b7decb
Remove unused import in test_warpctc_op.py
aztice Oct 30, 2025
4912b43
Enable and disable static mode in warpctc test
aztice Oct 30, 2025
f2b9f30
Update test/legacy_test/test_warpctc_op.py
aztice Oct 31, 2025
66a8efd
Replace assertTrue with assert_allclose in tests
aztice Oct 31, 2025
53b643e
Update test/legacy_test/test_warpctc_op.py
SigureMo Oct 31, 2025
13 changes: 12 additions & 1 deletion python/paddle/nn/functional/loss.py
@@ -1911,6 +1911,7 @@ def ctc_loss(
blank: int = 0,
reduction: _ReduceMode = 'mean',
norm_by_times: bool = False,
zero_infinity: bool = False,
) -> Tensor:
"""

@@ -1927,6 +1928,7 @@
blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default: 0.
reduction (str, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss is divided by the label lengths and the mean of the quotients is returned; if :attr:`reduction` is ``'sum'``, the sum of the loss is returned; if :attr:`reduction` is ``'none'``, no reduction is applied. Default: ``'mean'``.
norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-steps, which is also the sequence length. There is no need to normalize the gradients if the reduction mode is ``'mean'``. Default: False.
zero_infinity (bool, optional): If True, infinite loss values are set to zero. Infinite losses mainly occur when the input sequence is too short to be aligned to the labels. Default: False.

Returns:
Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size]; otherwise, the shape of loss is []. The data type is the same as ``log_probs``.
@@ -2041,8 +2043,17 @@ def warpctc(
loss_out = warpctc(
log_probs, labels, blank, norm_by_times, input_lengths, label_lengths
)

loss_out = paddle.squeeze(loss_out, [-1])

if zero_infinity:
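# Samples with no feasible alignment (e.g. the input is shorter than the
# label) yield an inf loss; zero those entries so reductions stay finite.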
inf_mask = paddle.isinf(loss_out)
zero_value = paddle.zeros_like(loss_out)
loss_out = paddle.where(
condition=inf_mask,
x=zero_value,
y=loss_out,
)

assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.mean(loss_out / label_lengths.astype(loss_out.dtype))
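A minimal dygraph sketch of the new flag (not part of the diff; the shapes and blank index are illustrative, mirroring the tests added below): with an input of length 1 and a 3-token label, no CTC alignment exists, so the raw loss is inf and ``zero_infinity=True`` maps it to 0.

```python
import numpy as np
import paddle

# Input length 1 cannot be aligned to a 3-token label, so the raw
# CTC loss is inf; zero_infinity=True replaces it with 0.
log_probs = paddle.to_tensor(np.random.randn(1, 1, 8).astype("float32"))
labels = paddle.to_tensor(np.array([[1, 2, 3]], dtype="int32"))
input_lengths = paddle.to_tensor(np.array([1], dtype="int64"))
label_lengths = paddle.to_tensor(np.array([3], dtype="int64"))

loss = paddle.nn.functional.ctc_loss(
    log_probs,
    labels,
    input_lengths,
    label_lengths,
    blank=7,  # use the last class as blank, matching the tests below
    reduction="none",
    zero_infinity=True,
)
print(float(loss[0]))  # 0.0 instead of inf
```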
2 changes: 2 additions & 0 deletions python/paddle/nn/layer/loss.py
@@ -1352,6 +1352,7 @@ def forward(
input_lengths: Tensor,
label_lengths: Tensor,
norm_by_times: bool = False,
zero_infinity: bool = False,
) -> Tensor:
return paddle.nn.functional.ctc_loss(
log_probs,
@@ -1361,6 +1362,7 @@
self.blank,
self.reduction,
norm_by_times=norm_by_times,
zero_infinity=zero_infinity,
)


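The layer API gets the same knob: ``CTCLoss.forward`` now accepts ``zero_infinity`` and forwards it to the functional call. A matching sketch (same illustrative shapes as above):

```python
import numpy as np
import paddle

# CTCLoss layer counterpart of the functional sketch above.
ctc = paddle.nn.CTCLoss(blank=7, reduction="none")
log_probs = paddle.to_tensor(np.random.randn(1, 1, 8).astype("float32"))
labels = paddle.to_tensor(np.array([[1, 2, 3]], dtype="int32"))
input_lengths = paddle.to_tensor(np.array([1], dtype="int64"))
label_lengths = paddle.to_tensor(np.array([3], dtype="int64"))

loss = ctc(
    log_probs, labels, input_lengths, label_lengths, zero_infinity=True
)
print(float(loss[0]))  # 0.0 for the unalignable sample
```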
83 changes: 83 additions & 0 deletions test/legacy_test/test_warpctc_op.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import sys
import unittest

@@ -754,6 +755,88 @@ def test_functional_api():

test_functional_api()

def test_ctc_loss_zero_infinity(self):
max_time = 1
batch = 1
n_class = 8
logits_np = np.random.randn(max_time, batch, n_class).astype("float32")
labels_np = np.random.randint(0, n_class - 1, (batch, 3)).astype(
"int32"
)
input_len_np = np.array([1], dtype=np.int64)
label_len_np = np.array([3], dtype=np.int64)

paddle.enable_static()
logits = paddle.static.data(
name="logits_il",
shape=[max_time, batch, n_class],
dtype="float32",
)
labels = paddle.static.data(
name="labels_il", shape=[batch, 3], dtype="int32"
)
input_len = paddle.static.data(
name="input_len_il", shape=[batch], dtype="int64"
)
label_len = paddle.static.data(
name="label_len_il", shape=[batch], dtype="int64"
)

loss = paddle.nn.functional.ctc_loss(
log_probs=logits,
labels=labels,
input_lengths=input_len,
label_lengths=label_len,
reduction="none",
zero_infinity=True,
blank=n_class - 1,
)

exe = paddle.static.Executor()
loss_val = exe.run(
feed={
"logits_il": logits_np,
"labels_il": labels_np,
"input_len_il": input_len_np,
"label_len_il": label_len_np,
},
fetch_list=[loss],
)[0]

# unalignable sample (input_len=1 < label_len=3): raw loss is inf, zeroed by zero_infinity
self.assertAlmostEqual(loss_val[0], 0.0, places=6)
paddle.disable_static()

def test_ctc_loss_zero_infinity_dygraph(self):
max_time = 1
batch = 1
n_class = 8

logits_np = np.random.randn(max_time, batch, n_class).astype("float32")
labels_np = np.random.randint(0, n_class - 1, (batch, 3)).astype(
"int32"
)
input_len_np = np.array([1], dtype=np.int64)
label_len_np = np.array([3], dtype=np.int64)

paddle.disable_static()
logits = paddle.to_tensor(logits_np)
labels = paddle.to_tensor(labels_np)
input_len = paddle.to_tensor(input_len_np)
label_len = paddle.to_tensor(label_len_np)

loss = paddle.nn.functional.ctc_loss(
log_probs=logits,
labels=labels,
input_lengths=input_len,
label_lengths=label_len,
reduction="none",
zero_infinity=True,
blank=n_class - 1,
)

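# input_len (1) < label_len (3): the raw loss is inf; zero_infinity zeroes it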
np.testing.assert_allclose(loss.numpy()[0], 0.0, atol=1e-6)


if __name__ == "__main__":
unittest.main()